[TREE] add interaction constraints (#3466)

* add interaction constraints

* enable both interaction and monotonic constraints at the same time

* fix lint

* add R test, fix lint, update demo

* Use dmlc::JSONReader to express interaction constraints as nested lists; Use sparse arrays for bookkeeping

* Add Python test for interaction constraints

* make R interaction constraints parameter based on feature index instead of column names, fix R coding style

* Fix lint

* Add BlueTea88 to CONTRIBUTORS.md

* Short circuit when no constraint is specified; address review comments

* Add tutorial for feature interaction constraints

* allow interaction constraints to be passed as string, remove redundant column_names argument

* Fix typo

* Address review comments

* Add comments to Python test
This commit is contained in:
Andrew Thia 2018-09-05 02:35:39 +10:00 committed by Philip Hyunsu Cho
parent dee0b69674
commit 9254c58e4d
12 changed files with 581 additions and 3 deletions

View File

@ -78,3 +78,5 @@ List of Contributors
* [Pierre de Sahb](https://github.com/pdesahb)
* [liuliang01](https://github.com/liuliang01)
- liuliang01 added support for the qid column for LibSVM input format. This makes ranking task easier in distributed setting.
* [Andrew Thia](https://github.com/BlueTea88)
- Andrew Thia implemented feature interaction constraints

View File

@ -74,6 +74,19 @@ check.booster.params <- function(params, ...) {
params[['monotone_constraints']] = vec2str
}
# interaction constraints parser (convert from list of column indices to string)
if (!is.null(params[['interaction_constraints']]) &&
typeof(params[['interaction_constraints']]) != "character"){
# check input class
if (class(params[['interaction_constraints']]) != 'list') stop('interaction_constraints should be class list')
if (!all(unique(sapply(params[['interaction_constraints']], class)) %in% c('numeric','integer'))) {
stop('interaction_constraints should be a list of numeric/integer vectors')
}
# recast parameter as string
interaction_constraints <- sapply(params[['interaction_constraints']], function(x) paste0('[', paste(x, collapse=','), ']'))
params[['interaction_constraints']] <- paste0('[', paste(interaction_constraints, collapse=','), ']')
}
return(params)
}

View File

@ -26,6 +26,7 @@
#' \item \code{colsample_bytree} subsample ratio of columns when constructing each tree. Default: 1
#' \item \code{num_parallel_tree} Experimental parameter. number of trees to grow per round. Useful to test Random Forest through Xgboost (set \code{colsample_bytree < 1}, \code{subsample < 1} and \code{round = 1}) accordingly. Default: 1
#' \item \code{monotone_constraints} A numerical vector consisting of \code{1}, \code{0} and \code{-1} whose length equals the number of features in the training data. \code{1} is increasing, \code{-1} is decreasing and \code{0} is no constraint.
#' \item \code{interaction_constraints} A list of vectors specifying feature indices of permitted interactions. Each item of the list represents one permitted interaction where specified features are allowed to interact with each other. Feature index values should start from \code{0} (\code{0} references the first column). Leave argument unspecified for no interaction constraints.
#' }
#'
#' 2.2. Parameter for Linear Booster

View File

@ -0,0 +1,105 @@
library(xgboost)
library(data.table)
set.seed(1024)
# Collect the distinct feature interactions fitted by a tree model.
#
# Arguments:
#   input_tree       data.table of tree nodes. The columns ID, Feature, Split,
#                    Yes and No are read (this script passes the output of
#                    xgb.model.dt.tree).
#   input_max_depth  maximum tree depth; ancestors up to this many levels above
#                    a node are looked up.
#
# Returns a list of sorted character vectors, one per distinct combination of
# two or more features that occur together on the path from an ancestor down
# to a split node. Returns an empty list when no interaction can exist
# (depth < 2, single-row input) or when none was found.
treeInteractions <- function(input_tree, input_max_depth) {
  if (input_max_depth < 2) return(list())  # no interactions if max depth < 2
  if (nrow(input_tree) == 1) return(list())
  trees <- copy(input_tree)  # copy tree input to prevent overwriting

  # Attach parent nodes: pass i adds the columns parent_<i-1> and
  # parent_feat_<i-1>, i.e. the ancestor (i-1) levels above each node.
  for (i in 2:input_max_depth) {
    if (i == 2) {
      trees[, ID_merge := ID]
    } else {
      trees[, ID_merge := get(paste0('parent_', i - 2))]
    }
    # A split node is the parent of the nodes named in its Yes and No columns
    parents_left <- trees[!is.na(Split), list(i.id = ID, i.feature = Feature, ID_merge = Yes)]
    parents_right <- trees[!is.na(Split), list(i.id = ID, i.feature = Feature, ID_merge = No)]

    setorderv(trees, 'ID_merge')
    setorderv(parents_left, 'ID_merge')
    setorderv(parents_right, 'ID_merge')

    trees <- merge(trees, parents_left, by = 'ID_merge', all.x = TRUE)
    trees[!is.na(i.id), c(paste0('parent_', i - 1), paste0('parent_feat_', i - 1)) := list(i.id, i.feature)]
    trees[, c('i.id', 'i.feature') := NULL]

    trees <- merge(trees, parents_right, by = 'ID_merge', all.x = TRUE)
    trees[!is.na(i.id), c(paste0('parent_', i - 1), paste0('parent_feat_', i - 1)) := list(i.id, i.feature)]
    trees[, c('i.id', 'i.feature') := NULL]
  }

  # Extract nodes with interactions
  interaction_trees <- trees[!is.na(Split) & !is.na(parent_1),
                             c('Feature', paste0('parent_feat_', 1:(input_max_depth - 1))),
                             with = FALSE]
  # Nothing to report; also avoids splitting by the malformed index 1:0
  if (nrow(interaction_trees) == 0) return(list())
  interaction_trees_split <- split(interaction_trees, seq_len(nrow(interaction_trees)))
  interaction_list <- lapply(interaction_trees_split, as.character)

  # Remove NAs (no parent interaction)
  interaction_list <- lapply(interaction_list, function(x) x[!is.na(x)])

  # Remove non-interactions (same variable)
  interaction_list <- lapply(interaction_list, unique)
  interaction_list <- interaction_list[lengths(interaction_list) > 1]

  # Sorting first makes c('V1','V2') and c('V2','V1') count as one interaction
  unique(lapply(interaction_list, sort))
}
# Generate sample data: ten features, where feature k is k * rnorm(1000, 10).
# The list is unnamed, so as.data.table() names the columns V1..V10.
x <- as.data.table(lapply(1:10, function(k) k * rnorm(1000, 10)))

# Response: a negative main effect of every feature, a V1*V2 interaction,
# a V3*V4*V5 interaction, a sine term in V7 and random noise.
# NOTE(review): rnorm(1000, 0.001) sets the noise *mean* to 0.001 and leaves
# sd at 1; confirm whether sd = 0.001 was intended.
y <- -1 * x[, rowSums(.SD)] + x[['V1']] * x[['V2']] +
  x[['V3']] * x[['V4']] * x[['V5']] + rnorm(1000, 0.001) + 3 * sin(x[['V7']])
train <- as.matrix(x)

# Interaction constraint list (column names form)
interaction_list <- list(c('V1', 'V2'), c('V3', 'V4', 'V5'))
# Convert interaction constraint list into feature index form.
#
# Arguments:
#   object     a (possibly nested) list whose character vectors hold column names
#   col_names  character vector of all column names, in feature order
#
# Returns `object` with every character vector replaced by the matching
# 0-based positions in `col_names` (a numeric vector named by column).
# Stops with an error if a name is not found in `col_names`, instead of
# silently producing NA indices.
cols2ids <- function(object, col_names) {
  # Lookup table: column name -> 0-based position
  lut <- seq_along(col_names) - 1
  names(lut) <- col_names
  rapply(object, function(x) {
    unknown <- setdiff(x, col_names)
    if (length(unknown) > 0) {
      stop("unknown column name(s) in interaction constraints: ",
           paste(unknown, collapse = ", "), call. = FALSE)
    }
    lut[x]
  }, classes = "character", how = "replace")
}
# Translate the column-name constraints into 0-based feature indices
interaction_list_fid <- cols2ids(interaction_list, colnames(train))

# Model 1: interaction constraints only
bst <- xgboost(
  data = train, label = y,
  max_depth = 4, eta = 0.1, nthread = 2, nrounds = 1000,
  interaction_constraints = interaction_list_fid
)
bst_tree <- xgb.model.dt.tree(colnames(train), bst)
# Interactions should be limited to combinations of V1*V2 and V3*V4*V5
bst_interactions <- treeInteractions(bst_tree, 4)

# Model 2: no constraints at all, for comparison
bst2 <- xgboost(
  data = train, label = y,
  max_depth = 4, eta = 0.1, nthread = 2, nrounds = 1000
)
bst2_tree <- xgb.model.dt.tree(colnames(train), bst2)
# Expect many more interactions than in bst_interactions
bst2_interactions <- treeInteractions(bst2_tree, 4)

# Model 3: interaction constraints plus a decreasing monotone constraint on V1
bst3 <- xgboost(
  data = train, label = y,
  max_depth = 4, eta = 0.1, nthread = 2, nrounds = 1000,
  interaction_constraints = interaction_list_fid,
  monotone_constraints = c(-1,0,0,0,0,0,0,0,0,0)
)
bst3_tree <- xgb.model.dt.tree(colnames(train), bst3)
# Interactions should still be limited to combinations of V1*V2 and V3*V4*V5
bst3_interactions <- treeInteractions(bst3_tree, 4)
# Show monotonic constraints still apply by checking scores after incrementing V1.
# Every row gets the same V1 value, stepping through the observed values in
# increasing order; with the -1 constraint on V1 no prediction may go up.
x1 <- sort(unique(x[['V1']]))
prev_pred <- NULL
for (i in seq_along(x1)) {
  # Overwrite V1 in a copy so the column order V1..V10 is kept as trained
  testdata <- copy(x)
  testdata[, V1 := x1[i]]
  pred <- predict(bst3, as.matrix(testdata))

  # Should not print out anything due to monotonic constraints
  if (!is.null(prev_pred) && any(pred > prev_pred)) print(i)
  prev_pred <- pred
}

View File

@ -0,0 +1,38 @@
require(xgboost)
context("interaction constraints")
set.seed(1024)
# Two continuous features, each 1000 draws from a normal with mean 1
x1 <- rnorm(1000, 1)
x2 <- rnorm(1000, 1)
# One discrete feature taking the values 1, 2 or 3
x3 <- sample(c(1,2,3), size=1000, replace=TRUE)
# The response contains an x1*x2*x3 term, so an unconstrained model has a
# genuine three-way interaction to pick up
y <- x1 + x2 + x3 + x1*x2*x3 + rnorm(1000, 0.001) + 3*sin(x1)
# Feature matrix with columns in the order x1, x2, x3
train <- matrix(c(x1,x2,x3), ncol = 3)
test_that("interaction constraints for regression", {
  # Fit a model that only allows interaction between x1 and x2
  # (feature indices are 0-based, so c(0,1) names the first two columns)
  bst <- xgboost(data = train, label = y, max_depth = 3,
                 eta = 0.1, nthread = 2, nrounds = 100, verbose = 0,
                 interaction_constraints = list(c(0,1)))

  # Set all observations to have the same x3 values then increment
  # by the same amount
  preds <- lapply(c(1,2,3), function(x){
    tmat <- matrix(c(x1,x2,rep(x,1000)), ncol=3)
    predict(bst, tmat)
  })

  # Check incrementing x3 has the same effect on all observations
  # since x3 is constrained to be independent of x1 and x2
  # and all observations start off from the same x3 value
  diff1 <- preds[[2]] - preds[[1]]
  test1 <- all(abs(diff1 - diff1[1]) < 1e-4)

  diff2 <- preds[[3]] - preds[[2]]
  test2 <- all(abs(diff2 - diff2[1]) < 1e-4)

  # test1 and test2 are single logicals, so use the scalar operator
  expect_true(test1 && test2, "Interaction Constraint Satisfied")
})

View File

@ -41,7 +41,7 @@ sys.path.insert(0, curr_path)
# -- mock out modules
import mock
MOCK_MODULES = ['numpy', 'scipy', 'scipy.sparse', 'sklearn', 'matplotlib', 'pandas', 'graphviz']
MOCK_MODULES = ['scipy', 'scipy.sparse', 'sklearn', 'pandas']
for mod_name in MOCK_MODULES:
sys.modules[mod_name] = mock.Mock()
@ -62,6 +62,8 @@ release = xgboost.__version__
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom ones
extensions = [
'matplotlib.sphinxext.only_directives',
'matplotlib.sphinxext.plot_directive',
'sphinx.ext.autodoc',
'sphinx.ext.napoleon',
'sphinx.ext.mathjax',
@ -69,6 +71,11 @@ extensions = [
'breathe'
]
# Image format for rendered Graphviz diagrams
# (presumably read by sphinx.ext.graphviz, which is not visible in this excerpt; confirm)
graphviz_output_format = 'png'
# Settings for the matplotlib plot_directive extension listed in `extensions`.
# NOTE(review): per the plot_directive docs each tuple is (file suffix, dpi); confirm.
plot_formats = [('svg', 300), ('png', 100), ('hires.png', 300)]
# Presumably hides the "source code" link under each rendered plot; confirm
plot_html_show_source_link = False
# Presumably hides the per-format download links under each rendered plot; confirm
plot_html_show_formats = False
# Breathe extension variables
breathe_projects = {"xgboost": "doxyxml/"}
breathe_default_project = "xgboost"

View File

@ -0,0 +1,177 @@
###############################
Feature Interaction Constraints
###############################
The decision tree is a powerful tool to discover interaction among independent
variables (features). Variables that appear together in a traversal path
are interacting with one another, since the condition of a child node is
predicated on the condition of the parent node. For example, the highlighted
red path in the diagram below contains three variables: :math:`x_1`, :math:`x_7`,
and :math:`x_{10}`, so the highlighted prediction (at the highlighted leaf node)
is the product of interaction between :math:`x_1`, :math:`x_7`, and
:math:`x_{10}`.
.. plot::
:nofigs:
from graphviz import Source
source = r"""
digraph feature_interaction_illustration1 {
graph [fontname = "helvetica"];
node [fontname = "helvetica"];
edge [fontname = "helvetica"];
0 [label=<x<SUB><FONT POINT-SIZE="11">10</FONT></SUB> &lt; -1.5 ?>, shape=box, color=red, fontcolor=red];
1 [label=<x<SUB><FONT POINT-SIZE="11">2</FONT></SUB> &lt; 2 ?>, shape=box];
2 [label=<x<SUB><FONT POINT-SIZE="11">7</FONT></SUB> &lt; 0.3 ?>, shape=box, color=red, fontcolor=red];
3 [label="...", shape=none];
4 [label="...", shape=none];
5 [label=<x<SUB><FONT POINT-SIZE="11">1</FONT></SUB> &lt; 0.5 ?>, shape=box, color=red, fontcolor=red];
6 [label="...", shape=none];
7 [label="...", shape=none];
8 [label="Predict +1.3", color=red, fontcolor=red];
0 -> 1 [labeldistance=2.0, labelangle=45, headlabel="Yes/Missing "];
0 -> 2 [labeldistance=2.0, labelangle=-45,
headlabel="No", color=red, fontcolor=red];
1 -> 3 [labeldistance=2.0, labelangle=45, headlabel="Yes"];
1 -> 4 [labeldistance=2.0, labelangle=-45, headlabel=" No/Missing"];
2 -> 5 [labeldistance=2.0, labelangle=-45, headlabel="Yes",
color=red, fontcolor=red];
2 -> 6 [labeldistance=2.0, labelangle=-45, headlabel=" No/Missing"];
5 -> 7;
5 -> 8 [color=red];
}
"""
Source(source, format='png').render('../_static/feature_interaction_illustration1', view=False)
Source(source, format='svg').render('../_static/feature_interaction_illustration1', view=False)
.. raw:: html
<p>
<img src="../_static/feature_interaction_illustration1.svg"
onerror="this.src='../_static/feature_interaction_illustration1.png'; this.onerror=null;">
</p>
When the tree depth is larger than one, many variables interact on
the sole basis of minimizing training loss, and the resulting decision tree may
capture a spurious relationship (noise) rather than a legitimate relationship
that generalizes across different datasets. **Feature interaction constraints**
allow users to decide which variables are allowed to interact and which are not.
Potential benefits include:
* Better predictive performance from focusing on interactions that work --
whether through domain specific knowledge or algorithms that rank interactions
* Less noise in predictions; better generalization
* More control to the user on what the model can fit. For example, the user may
want to exclude some interactions even if they perform well due to regulatory
constraints
****************
A Simple Example
****************
Feature interaction constraints are expressed in terms of groups of variables
that are allowed to interact. For example, the constraint
``[0, 1]`` indicates that variables :math:`x_0` and :math:`x_1` are allowed to
interact with each other but with no other variable. Similarly, ``[2, 3, 4]``
indicates that :math:`x_2`, :math:`x_3`, and :math:`x_4` are allowed to
interact with one another but with no other variable. A set of feature
interaction constraints is expressed as a nested list, e.g.
``[[0, 1], [2, 3, 4]]``, where each inner list is a group of indices of features
that are allowed to interact with each other.
In the following diagram, the left decision tree is in violation of the first
constraint (``[0, 1]``), whereas the right decision tree complies with both the
first and second constraints (``[0, 1]``, ``[2, 3, 4]``).
.. plot::
:nofigs:
from graphviz import Source
source = r"""
digraph feature_interaction_illustration2 {
graph [fontname = "helvetica"];
node [fontname = "helvetica"];
edge [fontname = "helvetica"];
0 [label=<x<SUB><FONT POINT-SIZE="11">0</FONT></SUB> &lt; 5.0 ?>, shape=box];
1 [label=<x<SUB><FONT POINT-SIZE="11">2</FONT></SUB> &lt; -3.0 ?>, shape=box];
2 [label="+0.6"];
3 [label="-0.4"];
4 [label="+1.2"];
0 -> 1 [labeldistance=2.0, labelangle=45, headlabel="Yes/Missing "];
0 -> 2 [labeldistance=2.0, labelangle=-45, headlabel="No"];
1 -> 3 [labeldistance=2.0, labelangle=45, headlabel="Yes"];
1 -> 4 [labeldistance=2.0, labelangle=-45, headlabel=" No/Missing"];
}
"""
Source(source, format='png').render('../_static/feature_interaction_illustration2', view=False)
Source(source, format='svg').render('../_static/feature_interaction_illustration2', view=False)
.. plot::
:nofigs:
from graphviz import Source
source = r"""
digraph feature_interaction_illustration3 {
graph [fontname = "helvetica"];
node [fontname = "helvetica"];
edge [fontname = "helvetica"];
0 [label=<x<SUB><FONT POINT-SIZE="11">3</FONT></SUB> &lt; 2.5 ?>, shape=box];
1 [label="+1.6"];
2 [label=<x<SUB><FONT POINT-SIZE="11">2</FONT></SUB> &lt; -1.2 ?>, shape=box];
3 [label="+0.1"];
4 [label="-0.3"];
0 -> 1 [labeldistance=2.0, labelangle=45, headlabel="Yes"];
0 -> 2 [labeldistance=2.0, labelangle=-45, headlabel=" No/Missing"];
2 -> 3 [labeldistance=2.0, labelangle=45, headlabel="Yes/Missing "];
2 -> 4 [labeldistance=2.0, labelangle=-45, headlabel="No"];
}
"""
Source(source, format='png').render('../_static/feature_interaction_illustration3', view=False)
Source(source, format='svg').render('../_static/feature_interaction_illustration3', view=False)
.. raw:: html
<p>
<img src="../_static/feature_interaction_illustration2.svg"
onerror="this.src='../_static/feature_interaction_illustration2.png'; this.onerror=null;">
<img src="../_static/feature_interaction_illustration3.svg"
onerror="this.src='../_static/feature_interaction_illustration3.png'; this.onerror=null;">
</p>
****************************************************
Enforcing Feature Interaction Constraints in XGBoost
****************************************************
It is very simple to enforce feature interaction constraints in XGBoost. Here we
will give an example using Python, but the same general idea generalizes to other
platforms.
Suppose the following code fits your model without feature interaction constraints:
.. code-block:: python
model_no_constraints = xgb.train(params, dtrain,
num_boost_round = 1000, evals = evallist,
early_stopping_rounds = 10)
Then fitting with feature interaction constraints only requires adding a single
parameter:
.. code-block:: python
params_constrained = params.copy()
# Use nested list to define feature interaction constraints
params_constrained['interaction_constraints'] = '[[0, 2], [1, 3, 4], [5, 6]]'
# Features 0 and 2 are allowed to interact with each other but with no other feature
# Features 1, 3, 4 are allowed to interact with one another but with no other feature
# Features 5 and 6 are allowed to interact with each other but with no other feature
model_with_constraints = xgb.train(params_constrained, dtrain,
num_boost_round = 1000, evals = evallist,
early_stopping_rounds = 10)
**Choice of tree construction algorithm**. To use feature interaction
constraints, be sure to set the ``tree_method`` parameter to either ``exact``
or ``hist``. Currently, GPU algorithms (``gpu_hist``, ``gpu_exact``) do not
support feature interaction constraints.

View File

@ -14,6 +14,7 @@ See `Awesome XGBoost <https://github.com/dmlc/xgboost/tree/master/demo>`_ for mo
Distributed XGBoost with XGBoost4J-Spark <https://xgboost.readthedocs.io/en/latest/jvm/xgboost4j_spark_tutorial.html>
dart
monotonic
feature_interaction_constraint
input_format
param_tuning
external_memory

View File

@ -82,7 +82,7 @@ Some other examples:
- ``(1,0)``: An increasing constraint on the first predictor and no constraint on the second.
- ``(0,-1)``: No constraint on the first predictor and a decreasing constraint on the second.
**Choise of tree construction algorithm**. To use monotonic constraints, be
**Choice of tree construction algorithm**. To use monotonic constraints, be
sure to set the ``tree_method`` parameter to one of ``exact``, ``hist``, and
``gpu_hist``.

View File

@ -194,7 +194,7 @@ struct TrainParam : public dmlc::Parameter<TrainParam> {
.describe("Number of rows in a GPU batch, used for finding quantiles on GPU; "
"-1 to use all rows assignted to a GPU, and 0 to auto-deduce");
DMLC_DECLARE_FIELD(split_evaluator)
.set_default("elastic_net,monotonic")
.set_default("elastic_net,monotonic,interaction")
.describe("The criteria to use for ranking splits");
// add alias of parameters
DMLC_DECLARE_ALIAS(reg_lambda, lambda);

View File

@ -4,8 +4,11 @@
* \brief Contains implementations of different split evaluators.
*/
#include "split_evaluator.h"
#include <dmlc/json.h>
#include <dmlc/registry.h>
#include <algorithm>
#include <unordered_set>
#include <vector>
#include <limits>
#include <string>
#include <sstream>
@ -303,5 +306,196 @@ XGBOOST_REGISTER_SPLIT_EVALUATOR(MonotonicConstraint, "monotonic")
return new MonotonicConstraint(std::move(inner));
});
/*! \brief Encapsulates the parameters required by the InteractionConstraint
 *         split evaluator
 */
struct InteractionConstraintParams
    : public dmlc::Parameter<InteractionConstraintParams> {
  /*! \brief Permitted interactions as a nested list in text form, e.g.
   *         "[[0, 1], [2, 3, 4]]"; the empty string (default) means no constraint
   */
  std::string interaction_constraints;
  /*! \brief Total number of features; no default is declared for this field */
  bst_uint num_feature;

  DMLC_DECLARE_PARAMETER(InteractionConstraintParams) {
    // Each literal ends with a space: adjacent string literals are joined
    // verbatim, so without it the sentences of the help text run together.
    DMLC_DECLARE_FIELD(interaction_constraints)
      .set_default("")
      .describe("Constraints for interaction representing permitted interactions. "
                "The constraints must be specified in the form of a nested list, "
                "e.g. [[0, 1], [2, 3, 4]], where each inner list is a group of "
                "indices of features that are allowed to interact with each other. "
                "See tutorial for more information");
    DMLC_DECLARE_FIELD(num_feature)
      .describe("Number of total features used");
  }
};

DMLC_REGISTER_PARAMETER(InteractionConstraintParams);
/*! \brief Enforces that the tree is monotonically increasing/decreasing with respect to a user specified set of
features.
*/
class InteractionConstraint final : public SplitEvaluator {
public:
explicit InteractionConstraint(std::unique_ptr<SplitEvaluator> inner) {
if (!inner) {
LOG(FATAL) << "InteractionConstraint must be given an inner evaluator";
}
inner_ = std::move(inner);
}
void Init(const std::vector<std::pair<std::string, std::string> >& args)
override {
inner_->Init(args);
params_.InitAllowUnknown(args);
Reset();
}
void Reset() override {
if (params_.interaction_constraints.empty()) {
return; // short-circuit if no constraint is specified
}
// Parse interaction constraints
std::istringstream iss(params_.interaction_constraints);
dmlc::JSONReader reader(&iss);
// Read std::vector<std::vector<bst_uint>> first and then
// convert to std::vector<std::unordered_set<bst_uint>>
std::vector<std::vector<bst_uint>> tmp;
reader.Read(&tmp);
for (const auto& e : tmp) {
interaction_constraints_.emplace_back(e.begin(), e.end());
}
// Initialise interaction constraints record with all variables permitted for the first node
int_cont_.clear();
int_cont_.resize(1, std::unordered_set<bst_uint>());
int_cont_[0].reserve(params_.num_feature);
for (bst_uint i = 0; i < params_.num_feature; ++i) {
int_cont_[0].insert(i);
}
// Initialise splits record
splits_.clear();
splits_.resize(1, std::unordered_set<bst_uint>());
}
SplitEvaluator* GetHostClone() const override {
if (params_.interaction_constraints.empty()) {
// No interaction constraints specified, just return a clone of inner
return inner_->GetHostClone();
} else {
auto c = new InteractionConstraint(
std::unique_ptr<SplitEvaluator>(inner_->GetHostClone()));
c->params_ = this->params_;
c->Reset();
return c;
}
}
bst_float ComputeSplitScore(bst_uint nodeid,
bst_uint featureid,
const GradStats& left_stats,
const GradStats& right_stats,
bst_float left_weight,
bst_float right_weight) const override {
// Return negative infinity score if feature is not permitted by interaction constraints
if (!CheckInteractionConstraint(featureid, nodeid)) {
return -std::numeric_limits<bst_float>::infinity();
}
// Otherwise, get score from inner evaluator
bst_float score = inner_->ComputeSplitScore(
nodeid, featureid, left_stats, right_stats, left_weight, right_weight);
return score;
}
bst_float ComputeScore(bst_uint parentID, const GradStats& stats, bst_float weight)
const override {
return inner_->ComputeScore(parentID, stats, weight);
}
bst_float ComputeWeight(bst_uint parentID, const GradStats& stats)
const override {
return inner_->ComputeWeight(parentID, stats);
}
void AddSplit(bst_uint nodeid,
bst_uint leftid,
bst_uint rightid,
bst_uint featureid,
bst_float leftweight,
bst_float rightweight) override {
inner_->AddSplit(nodeid, leftid, rightid, featureid, leftweight, rightweight);
if (params_.interaction_constraints.empty()) {
return; // short-circuit if no constraint is specified
}
bst_uint newsize = std::max(leftid, rightid) + 1;
// Record previous splits for child nodes
std::unordered_set<bst_uint> feature_splits = splits_[nodeid]; // fid history of current node
feature_splits.insert(featureid); // add feature of current node
splits_.resize(newsize);
splits_[leftid] = feature_splits;
splits_[rightid] = feature_splits;
// Resize constraints record, initialise all features to be not permitted for new nodes
int_cont_.resize(newsize, std::unordered_set<bst_uint>());
// Permit features used in previous splits
for (bst_uint fid : feature_splits) {
int_cont_[leftid].insert(fid);
int_cont_[rightid].insert(fid);
}
// Loop across specified interactions in constraints
for (const auto& constraint : interaction_constraints_) {
bst_uint flag = 1; // flags whether the specified interaction is still relevant
// Test relevance of specified interaction by checking all previous features are included
for (bst_uint checkvar : feature_splits) {
if (constraint.count(checkvar) == 0) {
flag = 0;
break; // interaction is not relevant due to unmet constraint
}
}
// If interaction is still relevant, permit all other features in the interaction
if (flag == 1) {
for (bst_uint k : constraint) {
int_cont_[leftid].insert(k);
int_cont_[rightid].insert(k);
}
}
}
}
private:
InteractionConstraintParams params_;
std::unique_ptr<SplitEvaluator> inner_;
// interaction_constraints_[constraint_id] contains a single interaction
// constraint, which specifies a group of feature IDs that can interact
// with each other
std::vector< std::unordered_set<bst_uint> > interaction_constraints_;
// int_cont_[nid] contains the set of all feature IDs that are allowed to
// be used for a split at node nid
std::vector< std::unordered_set<bst_uint> > int_cont_;
// splits_[nid] contains the set of all feature IDs that have been used for
// splits in node nid and its parents
std::vector< std::unordered_set<bst_uint> > splits_;
// Check interaction constraints. Returns true if a given feature ID is
// permissible in a given node; returns false otherwise
inline bool CheckInteractionConstraint(bst_uint featureid, bst_uint nodeid) const {
// short-circuit if no constraint is specified
return (params_.interaction_constraints.empty()
|| int_cont_[nodeid].count(featureid) > 0);
}
};
// Registers InteractionConstraint under the name "interaction", the same name
// that appears in the default value of the split_evaluator training parameter.
// The factory wraps the inner evaluator it is handed in a new InteractionConstraint.
XGBOOST_REGISTER_SPLIT_EVALUATOR(InteractionConstraint, "interaction")
.describe("Enforces interaction constraints on tree features")
.set_body([](std::unique_ptr<SplitEvaluator> inner) {
return new InteractionConstraint(std::move(inner));
});
} // namespace tree
} // namespace xgboost

View File

@ -0,0 +1,40 @@
# -*- coding: utf-8 -*-
import numpy as np
import xgboost
import unittest
dpath = 'demo/data/'
rng = np.random.RandomState(1994)
class TestInteractionConstraints(unittest.TestCase):
    """Checks that a feature excluded from every interaction group acts additively."""

    def test_interaction_constraints(self):
        """Train with only x1 and x2 allowed to interact, then verify that
        changing x3 shifts every prediction by the same amount."""
        n_samples = 1000
        # Draw from the seeded module-level generator so the test is
        # reproducible; the global np.random state is unseeded.
        x1 = rng.normal(loc=1.0, scale=1.0, size=n_samples)
        x2 = rng.normal(loc=1.0, scale=1.0, size=n_samples)
        x3 = rng.choice([1, 2, 3], size=n_samples, replace=True)
        y = x1 + x2 + x3 + x1 * x2 * x3 \
            + rng.normal(loc=0.001, scale=1.0, size=n_samples) + 3 * np.sin(x1)
        X = np.column_stack((x1, x2, x3))
        dtrain = xgboost.DMatrix(X, label=y)

        params = {'max_depth': 3, 'eta': 0.1, 'nthread': 2, 'silent': 1,
                  'interaction_constraints': '[[0, 1]]'}
        num_boost_round = 100
        # Fit a model that only allows interaction between x1 and x2
        bst = xgboost.train(params, dtrain, num_boost_round, evals=[(dtrain, 'train')])

        # Set all observations to have the same x3 values then increment
        # by the same amount
        def f(x):
            tmat = xgboost.DMatrix(np.column_stack((x1, x2, np.repeat(x, n_samples))))
            return bst.predict(tmat)
        preds = [f(x) for x in [1, 2, 3]]

        # Check incrementing x3 has the same effect on all observations
        # since x3 is constrained to be independent of x1 and x2
        # and all observations start off from the same x3 value
        diff1 = preds[1] - preds[0]
        assert np.all(np.abs(diff1 - diff1[0]) < 1e-4)
        diff2 = preds[2] - preds[1]
        assert np.all(np.abs(diff2 - diff2[0]) < 1e-4)