Add prediction of feature contributions (#2003)

* Add prediction of feature contributions

This implements the idea described at http://blog.datadive.net/interpreting-random-forests/,
which gives insight into how a prediction decomposes into its per-feature contributions
plus a bias term (a toy numeric sketch of the decomposition follows the commit metadata below).

* Support multi-class models

* Calculate learning_rate per-tree instead of using the one from the first tree

* Do not rely on node.base_weight * learning_rate having the same value as the node mean value (i.e. the leaf value, if the node were a leaf); instead, calculate the mean values lazily on the fly

* Add simple test for contributions feature

* Check against param.num_nodes instead of checking for non-zero length

* Loop over all roots instead of only the first
Maurus Cuelenaere 2017-05-14 07:58:10 +02:00 committed by Vadim Khotilovich
parent e62be19c70
commit 6bd1869026
10 changed files with 205 additions and 5 deletions
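
Before the diff, a hedged toy illustration of the decomposition described in the blog post above (hand-picked numbers, no XGBoost calls). Each split along the decision path moves the running estimate from the parent node's mean prediction to the chosen child's mean; that delta is credited to the split feature, and the root mean acts as the bias:

```python
# Toy numbers for one instance's path through a single tree:
#   root (mean 0.50) --split on f2--> inner node (mean 0.80) --split on f0--> leaf (value 1.10)
bias = 0.50                # mean prediction over all training rows (root mean)
contrib_f2 = 0.80 - 0.50   # change in expected prediction credited to feature f2
contrib_f0 = 1.10 - 0.80   # change credited to feature f0 on the final step
prediction = 1.10          # leaf value the instance ends up in

# the decomposition is exact: bias + per-feature contributions == prediction
assert abs((bias + contrib_f2 + contrib_f0) - prediction) < 1e-12
```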

View File

@@ -382,6 +382,7 @@ XGB_DLL int XGBoosterEvalOneIter(BoosterHandle handle,
 * 0:normal prediction
 * 1:output margin instead of transformed value
 * 2:output leaf index of trees instead of leaf value, note leaf index is unique per tree
 * 4:output feature contributions of all trees instead of predictions
 * \param ntree_limit limit number of trees used for prediction, this is only valid for boosted trees
 *   when the parameter is set to 0, we will use all the trees
 * \param out_len used to store length of returning result

View File

@@ -107,6 +107,19 @@ class GradientBooster {
virtual void PredictLeaf(DMatrix* dmat,
std::vector<bst_float>* out_preds,
unsigned ntree_limit = 0) = 0;
/*!
 * \brief predict the feature contributions of each tree; the output will be an nsample * (nfeats + 1) vector
 * this is only valid for the gbtree booster
* \param dmat feature matrix
* \param out_contribs output vector to hold the contributions
* \param ntree_limit limit the number of trees used in prediction, when it equals 0, this means
* we do not limit number of trees
*/
virtual void PredictContribution(DMatrix* dmat,
std::vector<bst_float>* out_contribs,
unsigned ntree_limit = 0) = 0;
/*!
 * \brief dump the model in the requested format
 * \param fmap feature map that may help give interpretations of feature

View File

@@ -103,12 +103,14 @@ class Learner : public rabit::Serializable {
 * \param ntree_limit limit number of trees used for boosted tree
 *   predictor, when it equals 0, this means we are using all the trees
 * \param pred_leaf whether to only predict the leaf index of each tree in a boosted tree predictor
* \param pred_contribs whether to only predict the feature contributions of all trees
 */
virtual void Predict(DMatrix* data,
bool output_margin,
std::vector<bst_float> *out_preds,
unsigned ntree_limit = 0,
bool pred_leaf = false,
bool pred_contribs = false) const = 0;
/*!
 * \brief Set additional attribute to the Booster.
 *  The property will be saved along the booster.

View File

@@ -434,6 +434,11 @@ class RegTree: public TreeModel<bst_float, RTreeNodeStat> {
 * \param inst The sparse instance to drop.
 */
inline void Drop(const RowBatch::Inst& inst);
/*!
* \brief returns the size of the feature vector
* \return the size of the feature vector
*/
inline size_t size() const;
/*!
 * \brief get ith value
 * \param i feature index.
@@ -472,6 +477,14 @@ class RegTree: public TreeModel<bst_float, RTreeNodeStat> {
 * \return the leaf index of the given feature
 */
inline bst_float Predict(const FVec& feat, unsigned root_id = 0) const;
/*!
* \brief calculate the feature contributions for the given root
* \param feat dense feature vector, if the feature is missing the field is set to NaN
* \param root_id starting root index of the instance
* \param out_contribs output vector to hold the contributions
*/
inline void CalculateContributions(const RegTree::FVec& feat, unsigned root_id,
bst_float *out_contribs) const;
/*!
 * \brief get next position of the tree given current pid
 * \param pid Current node id.
@@ -489,6 +502,15 @@ class RegTree: public TreeModel<bst_float, RTreeNodeStat> {
std::string DumpModel(const FeatureMap& fmap,
bool with_stats,
std::string format) const;
/*!
* \brief calculate the mean value for each node, required for feature contributions
*/
inline void FillNodeMeanValues();
private:
inline bst_float FillNodeMeanValue(int nid);
std::vector<bst_float> node_mean_values;
};
// implementations of inline functions
@@ -513,6 +535,10 @@ inline void RegTree::FVec::Drop(const RowBatch::Inst& inst) {
}
}
inline size_t RegTree::FVec::size() const {
return data.size();
}
inline bst_float RegTree::FVec::fvalue(size_t i) const {
return data[i].fvalue;
}
@@ -535,6 +561,58 @@ inline bst_float RegTree::Predict(const RegTree::FVec& feat, unsigned root_id) c
return (*this)[pid].leaf_value();
}
inline void RegTree::FillNodeMeanValues() {
size_t num_nodes = this->param.num_nodes;
if (this->node_mean_values.size() == num_nodes) {
return;
}
this->node_mean_values.resize(num_nodes);
for (int root_id = 0; root_id < param.num_roots; ++root_id) {
this->FillNodeMeanValue(root_id);
}
}
inline bst_float RegTree::FillNodeMeanValue(int nid) {
bst_float result;
auto& node = (*this)[nid];
if (node.is_leaf()) {
result = node.leaf_value();
} else {
result = this->FillNodeMeanValue(node.cleft()) * this->stat(node.cleft()).sum_hess;
result += this->FillNodeMeanValue(node.cright()) * this->stat(node.cright()).sum_hess;
result /= this->stat(nid).sum_hess;
}
this->node_mean_values[nid] = result;
return result;
}
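
A hedged Python transcription of FillNodeMeanValue above, using toy dict-based nodes (the id/left/right/leaf_value/sum_hess field names are illustrative stand-ins, not XGBoost structures): an internal node's mean is the hessian-weighted average of its children's means, and a leaf's mean is simply its leaf value.

```python
def fill_node_mean(node, means):
    """Toy mirror of RegTree::FillNodeMeanValue; field names are illustrative."""
    if node["left"] is None:  # leaf: the mean is the leaf value itself
        mean = node["leaf_value"]
    else:                     # internal: hessian-weighted average of the children
        mean = (fill_node_mean(node["left"], means) * node["left"]["sum_hess"] +
                fill_node_mean(node["right"], means) * node["right"]["sum_hess"])
        mean /= node["sum_hess"]
    means[node["id"]] = mean
    return mean

# tiny two-leaf tree: 3 rows fall left (value 0.2), 1 row falls right (value 1.0)
leaf_l = {"id": 1, "left": None, "leaf_value": 0.2, "sum_hess": 3.0}
leaf_r = {"id": 2, "left": None, "leaf_value": 1.0, "sum_hess": 1.0}
root = {"id": 0, "left": leaf_l, "right": leaf_r, "sum_hess": 4.0}
means = {}
fill_node_mean(root, means)
assert abs(means[0] - 0.4) < 1e-12  # (0.2 * 3 + 1.0 * 1) / 4
```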
inline void RegTree::CalculateContributions(const RegTree::FVec& feat, unsigned root_id,
bst_float *out_contribs) const {
CHECK_GT(this->node_mean_values.size(), 0);
// this follows the idea of http://blog.datadive.net/interpreting-random-forests/
bst_float node_value;
unsigned split_index;
int pid = static_cast<int>(root_id);
// update bias value
node_value = this->node_mean_values[pid];
out_contribs[feat.size()] += node_value;
if ((*this)[pid].is_leaf()) {
// nothing to do anymore
return;
}
while (!(*this)[pid].is_leaf()) {
split_index = (*this)[pid].split_index();
pid = this->GetNext(pid, feat.fvalue(split_index), feat.is_missing(split_index));
bst_float new_value = this->node_mean_values[pid];
// update feature weight
out_contribs[split_index] += new_value - node_value;
node_value = new_value;
}
bst_float leaf_value = (*this)[pid].leaf_value();
// update leaf feature weight
out_contribs[split_index] += leaf_value - node_value;
}
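
And a hedged mirror of CalculateContributions itself, continuing the toy dict representation (again illustrative names, not the XGBoost API). One detail worth noting: because FillNodeMeanValue stores a leaf's own value as its mean, the explicit leaf-value correction at the end of the C++ adds zero, so this sketch folds it into the loop.

```python
def calculate_contributions(node, means, feat, out_contribs):
    """Toy mirror of RegTree::CalculateContributions; not the XGBoost API."""
    out_contribs[-1] += means[node["id"]]  # bias slot is the trailing entry
    value = means[node["id"]]
    while node["left"] is not None:        # descend until we reach a leaf
        split = node["split_index"]
        # missing-value (default direction) handling from GetNext is omitted here
        node = node["left"] if feat[split] < node["split_cond"] else node["right"]
        out_contribs[split] += means[node["id"]] - value  # credit the split feature
        value = means[node["id"]]

# same shape as the tree above: one split on feature 0 at threshold 0.5
leaf_l = {"id": 1, "left": None}
leaf_r = {"id": 2, "left": None}
root = {"id": 0, "left": leaf_l, "right": leaf_r,
        "split_index": 0, "split_cond": 0.5}
means = {0: 0.4, 1: 0.2, 2: 1.0}  # as computed by the previous sketch
contribs = [0.0, 0.0]             # one feature slot + one trailing bias slot
calculate_contributions(root, means, feat=[0.9], out_contribs=contribs)
assert abs(contribs[1] - 0.4) < 1e-12  # bias
assert abs(contribs[0] - 0.6) < 1e-12  # feature 0; bias + contribution == leaf value 1.0
```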
/*! \brief get next position of the tree given current pid */
inline int RegTree::GetNext(int pid, bst_float fvalue, bool is_unknown) const {
bst_float split_value = (*this)[pid].split_cond();

View File

@@ -911,7 +911,8 @@ class Booster(object):
        self._validate_features(data)
        return self.eval_set([(data, name)], iteration)

    def predict(self, data, output_margin=False, ntree_limit=0, pred_leaf=False,
                pred_contribs=False):
        """
        Predict with data.
@@ -937,6 +938,12 @@
        Note that the leaf index of a tree is unique per tree, so you may find leaf 1
        in both tree 1 and tree 0.

        pred_contribs : bool
            When this option is on, the output will be a matrix of (nsample, nfeats + 1)
            with each record indicating the feature contributions of all trees. The sum of
            all feature contributions is equal to the prediction. Note that the bias is added
            as the final column, on top of the regular features.
        Returns
        -------
        prediction : numpy array
@@ -946,6 +953,8 @@
            option_mask |= 0x01
        if pred_leaf:
            option_mask |= 0x02
        if pred_contribs:
            option_mask |= 0x04
        self._validate_features(data)

View File

@@ -622,7 +622,8 @@ XGB_DLL int XGBoosterPredict(BoosterHandle handle,
static_cast<std::shared_ptr<DMatrix>*>(dmat)->get(),
(option_mask & 1) != 0,
&preds, ntree_limit,
(option_mask & 2) != 0,
(option_mask & 4) != 0);
*out_result = dmlc::BeginPtr(preds);
*len = static_cast<xgboost::bst_ulong>(preds.size());
API_END();

View File

@@ -223,6 +223,11 @@ class GBLinear : public GradientBooster {
unsigned ntree_limit) override {
LOG(FATAL) << "gblinear does not support predict leaf index";
}
void PredictContribution(DMatrix* p_fmat,
std::vector<bst_float>* out_contribs,
unsigned ntree_limit) override {
LOG(FATAL) << "gblinear does not support predict contributions";
}
std::vector<std::string> DumpModel(const FeatureMap& fmap,
bool with_stats,

View File

@@ -322,6 +322,14 @@ class GBTree : public GradientBooster {
this->PredPath(p_fmat, out_preds, ntree_limit);
}
void PredictContribution(DMatrix* p_fmat,
std::vector<bst_float>* out_contribs,
unsigned ntree_limit) override {
const int nthread = omp_get_max_threads();
InitThreadTemp(nthread);
this->PredContrib(p_fmat, out_contribs, ntree_limit);
}
std::vector<std::string> DumpModel(const FeatureMap& fmap,
bool with_stats,
std::string format) const override {
@@ -553,6 +561,62 @@
}
}
}
// predict contributions
inline void PredContrib(DMatrix *p_fmat,
std::vector<bst_float> *out_contribs,
unsigned ntree_limit) {
const MetaInfo& info = p_fmat->info();
// number of valid trees
ntree_limit *= mparam.num_output_group;
if (ntree_limit == 0 || ntree_limit > trees.size()) {
ntree_limit = static_cast<unsigned>(trees.size());
}
size_t ncolumns = mparam.num_feature + 1;
// allocate (number of features + bias) entries per row and per output group
std::vector<bst_float>& contribs = *out_contribs;
contribs.resize(info.num_row * ncolumns * mparam.num_output_group);
// make sure the contributions are zeroed, as we could be reusing a previously allocated buffer
std::fill(contribs.begin(), contribs.end(), 0);
// initialize tree node mean values
#pragma omp parallel for schedule(static)
for (bst_omp_uint i = 0; i < ntree_limit; ++i) {
trees[i]->FillNodeMeanValues();
}
// start collecting the contributions
dmlc::DataIter<RowBatch>* iter = p_fmat->RowIterator();
const std::vector<bst_float>& base_margin = p_fmat->info().base_margin;
iter->BeforeFirst();
while (iter->Next()) {
const RowBatch& batch = iter->Value();
// parallel over local batch
const bst_omp_uint nsize = static_cast<bst_omp_uint>(batch.size);
#pragma omp parallel for schedule(static)
for (bst_omp_uint i = 0; i < nsize; ++i) {
size_t row_idx = static_cast<size_t>(batch.base_rowid + i);
unsigned root_id = info.GetRoot(row_idx);
RegTree::FVec &feats = thread_temp[omp_get_thread_num()];
// loop over all classes
for (int gid = 0; gid < mparam.num_output_group; ++gid) {
bst_float *p_contribs = &contribs[(row_idx * mparam.num_output_group + gid) * ncolumns];
feats.Fill(batch[i]);
// calculate contributions
for (unsigned j = 0; j < ntree_limit; ++j) {
if (tree_info[j] != gid) {
continue;
}
trees[j]->CalculateContributions(feats, root_id, p_contribs);
}
feats.Drop(batch[i]);
// add base margin to BIAS feature
if (base_margin.size() != 0) {
p_contribs[ncolumns - 1] += base_margin[row_idx * mparam.num_output_group + gid];
} else {
p_contribs[ncolumns - 1] += base_margin_;
}
}
}
}
}
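
The flat buffer filled by PredContrib is laid out row-major over (row, output group, feature-or-bias). A hedged numpy sketch of the indexing used above (the sizes are made up for illustration; the Python wrapper does its own reshaping when pred_contribs is set):

```python
import numpy as np

nrow, ngroup, nfeat = 100, 3, 10   # hypothetical sizes, for illustration only
ncolumns = nfeat + 1               # trailing column holds the bias

# buffer as PredContrib would size it: num_row * ncolumns * num_output_group
flat = np.zeros(nrow * ngroup * ncolumns, dtype=np.float32)

# the offset computed in the C++ above for a given (row_idx, gid):
row_idx, gid = 42, 1
offset = (row_idx * ngroup + gid) * ncolumns
one_record = flat[offset:offset + ncolumns]  # nfeat contributions + bias

# equivalently, view the whole buffer as (rows, groups, features + bias)
view = flat.reshape(nrow, ngroup, ncolumns)
assert np.shares_memory(view, flat)
assert view[row_idx, gid].shape == one_record.shape
```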
// init thread buffers
inline void InitThreadTemp(int nthread) {
int prev_thread_temp_size = thread_temp.size();

View File

@@ -400,8 +400,11 @@ class LearnerImpl : public Learner {
bool output_margin,
std::vector<bst_float> *out_preds,
unsigned ntree_limit,
bool pred_leaf,
bool pred_contribs) const override {
if (pred_contribs) {
gbm_->PredictContribution(data, out_preds, ntree_limit);
} else if (pred_leaf) {
gbm_->PredictLeaf(data, out_preds, ntree_limit);
} else {
this->PredictRaw(data, out_preds, ntree_limit);

View File

@@ -2,6 +2,7 @@
import numpy as np
import xgboost as xgb
import unittest
import itertools
import json

dpath = 'demo/data/'
@@ -250,3 +251,26 @@ class TestBasic(unittest.TestCase):
        cv = xgb.cv(params, dm, num_boost_round=10, shuffle=False, nfold=10, as_pandas=False)
        assert isinstance(cv, dict)
        assert len(cv) == (4)
def test_contributions():
    dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
    dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')

    def test_fn(max_depth, num_rounds):
        # train
        params = {'max_depth': max_depth, 'eta': 1, 'silent': 1}
        bst = xgb.train(params, dtrain, num_boost_round=num_rounds)
        # predict
        preds = bst.predict(dtest)
        contribs = bst.predict(dtest, pred_contribs=True)
        # result should be of shape (number of rows, number of features + BIAS)
        assert contribs.shape == (dtest.num_row(), dtest.num_col() + 1)
        # sum of contributions should be the same as the predictions
        np.testing.assert_array_almost_equal(np.sum(contribs, axis=1), preds)

    for max_depth, num_rounds in itertools.product(range(0, 3), range(1, 5)):
        yield test_fn, max_depth, num_rounds