Add prediction of feature contributions (#2003)
* Add prediction of feature contributions This implements the idea described at http://blog.datadive.net/interpreting-random-forests/ which tries to give insight in how a prediction is composed of its feature contributions and a bias. * Support multi-class models * Calculate learning_rate per-tree instead of using the one from the first tree * Do not rely on node.base_weight * learning_rate having the same value as the node mean value (aka leaf value, if it were a leaf); instead calculate them (lazily) on-the-fly * Add simple test for contributions feature * Check against param.num_nodes instead of checking for non-zero length * Loop over all roots instead of only the first
This commit is contained in:
parent
e62be19c70
commit
6bd1869026
@ -382,6 +382,7 @@ XGB_DLL int XGBoosterEvalOneIter(BoosterHandle handle,
|
||||
* 0:normal prediction
|
||||
* 1:output margin instead of transformed value
|
||||
* 2:output leaf index of trees instead of leaf value, note leaf index is unique per tree
|
||||
* 4:output feature contributions of all trees instead of predictions
|
||||
* \param ntree_limit limit number of trees used for prediction, this is only valid for boosted trees
|
||||
* when the parameter is set to 0, we will use all the trees
|
||||
* \param out_len used to store length of returning result
|
||||
|
||||
@ -107,6 +107,19 @@ class GradientBooster {
|
||||
virtual void PredictLeaf(DMatrix* dmat,
|
||||
std::vector<bst_float>* out_preds,
|
||||
unsigned ntree_limit = 0) = 0;
|
||||
|
||||
/*!
|
||||
* \brief predict the feature contributions of each tree, the output will be nsample * (nfeats + 1) vector
|
||||
* this is only valid in gbtree predictor
|
||||
* \param dmat feature matrix
|
||||
* \param out_contribs output vector to hold the contributions
|
||||
* \param ntree_limit limit the number of trees used in prediction, when it equals 0, this means
|
||||
* we do not limit number of trees
|
||||
*/
|
||||
virtual void PredictContribution(DMatrix* dmat,
|
||||
std::vector<bst_float>* out_contribs,
|
||||
unsigned ntree_limit = 0) = 0;
|
||||
|
||||
/*!
|
||||
* \brief dump the model in the requested format
|
||||
* \param fmap feature map that may help give interpretations of feature
|
||||
|
||||
@ -103,12 +103,14 @@ class Learner : public rabit::Serializable {
|
||||
* \param ntree_limit limit number of trees used for boosted tree
|
||||
* predictor, when it equals 0, this means we are using all the trees
|
||||
* \param pred_leaf whether to only predict the leaf index of each tree in a boosted tree predictor
|
||||
* \param pred_contribs whether to only predict the feature contributions of all trees
|
||||
*/
|
||||
virtual void Predict(DMatrix* data,
|
||||
bool output_margin,
|
||||
std::vector<bst_float> *out_preds,
|
||||
unsigned ntree_limit = 0,
|
||||
bool pred_leaf = false) const = 0;
|
||||
bool pred_leaf = false,
|
||||
bool pred_contribs = false) const = 0;
|
||||
/*!
|
||||
* \brief Set additional attribute to the Booster.
|
||||
* The property will be saved along the booster.
|
||||
|
||||
@ -434,6 +434,11 @@ class RegTree: public TreeModel<bst_float, RTreeNodeStat> {
|
||||
* \param inst The sparse instance to drop.
|
||||
*/
|
||||
inline void Drop(const RowBatch::Inst& inst);
|
||||
/*!
|
||||
* \brief returns the size of the feature vector
|
||||
* \return the size of the feature vector
|
||||
*/
|
||||
inline size_t size() const;
|
||||
/*!
|
||||
* \brief get ith value
|
||||
* \param i feature index.
|
||||
@ -472,6 +477,14 @@ class RegTree: public TreeModel<bst_float, RTreeNodeStat> {
|
||||
* \return the leaf index of the given feature
|
||||
*/
|
||||
inline bst_float Predict(const FVec& feat, unsigned root_id = 0) const;
|
||||
/*!
|
||||
* \brief calculate the feature contributions for the given root
|
||||
* \param feat dense feature vector, if the feature is missing the field is set to NaN
|
||||
* \param root_id starting root index of the instance
|
||||
* \param out_contribs output vector to hold the contributions
|
||||
*/
|
||||
inline void CalculateContributions(const RegTree::FVec& feat, unsigned root_id,
|
||||
bst_float *out_contribs) const;
|
||||
/*!
|
||||
* \brief get next position of the tree given current pid
|
||||
* \param pid Current node id.
|
||||
@ -489,6 +502,15 @@ class RegTree: public TreeModel<bst_float, RTreeNodeStat> {
|
||||
std::string DumpModel(const FeatureMap& fmap,
|
||||
bool with_stats,
|
||||
std::string format) const;
|
||||
/*!
|
||||
* \brief calculate the mean value for each node, required for feature contributions
|
||||
*/
|
||||
inline void FillNodeMeanValues();
|
||||
|
||||
private:
|
||||
inline bst_float FillNodeMeanValue(int nid);
|
||||
|
||||
std::vector<bst_float> node_mean_values;
|
||||
};
|
||||
|
||||
// implementations of inline functions
|
||||
@ -513,6 +535,10 @@ inline void RegTree::FVec::Drop(const RowBatch::Inst& inst) {
|
||||
}
|
||||
}
|
||||
|
||||
inline size_t RegTree::FVec::size() const {
|
||||
return data.size();
|
||||
}
|
||||
|
||||
inline bst_float RegTree::FVec::fvalue(size_t i) const {
|
||||
return data[i].fvalue;
|
||||
}
|
||||
@ -535,6 +561,58 @@ inline bst_float RegTree::Predict(const RegTree::FVec& feat, unsigned root_id) c
|
||||
return (*this)[pid].leaf_value();
|
||||
}
|
||||
|
||||
inline void RegTree::FillNodeMeanValues() {
|
||||
size_t num_nodes = this->param.num_nodes;
|
||||
if (this->node_mean_values.size() == num_nodes) {
|
||||
return;
|
||||
}
|
||||
this->node_mean_values.resize(num_nodes);
|
||||
for (int root_id = 0; root_id < param.num_roots; ++root_id) {
|
||||
this->FillNodeMeanValue(root_id);
|
||||
}
|
||||
}
|
||||
|
||||
inline bst_float RegTree::FillNodeMeanValue(int nid) {
|
||||
bst_float result;
|
||||
auto& node = (*this)[nid];
|
||||
if (node.is_leaf()) {
|
||||
result = node.leaf_value();
|
||||
} else {
|
||||
result = this->FillNodeMeanValue(node.cleft()) * this->stat(node.cleft()).sum_hess;
|
||||
result += this->FillNodeMeanValue(node.cright()) * this->stat(node.cright()).sum_hess;
|
||||
result /= this->stat(nid).sum_hess;
|
||||
}
|
||||
this->node_mean_values[nid] = result;
|
||||
return result;
|
||||
}
|
||||
|
||||
inline void RegTree::CalculateContributions(const RegTree::FVec& feat, unsigned root_id,
|
||||
bst_float *out_contribs) const {
|
||||
CHECK_GT(this->node_mean_values.size(), 0);
|
||||
// this follows the idea of http://blog.datadive.net/interpreting-random-forests/
|
||||
bst_float node_value;
|
||||
unsigned split_index;
|
||||
int pid = static_cast<int>(root_id);
|
||||
// update bias value
|
||||
node_value = this->node_mean_values[pid];
|
||||
out_contribs[feat.size()] += node_value;
|
||||
if ((*this)[pid].is_leaf()) {
|
||||
// nothing to do anymore
|
||||
return;
|
||||
}
|
||||
while (!(*this)[pid].is_leaf()) {
|
||||
split_index = (*this)[pid].split_index();
|
||||
pid = this->GetNext(pid, feat.fvalue(split_index), feat.is_missing(split_index));
|
||||
bst_float new_value = this->node_mean_values[pid];
|
||||
// update feature weight
|
||||
out_contribs[split_index] += new_value - node_value;
|
||||
node_value = new_value;
|
||||
}
|
||||
bst_float leaf_value = (*this)[pid].leaf_value();
|
||||
// update leaf feature weight
|
||||
out_contribs[split_index] += leaf_value - node_value;
|
||||
}
|
||||
|
||||
/*! \brief get next position of the tree given current pid */
|
||||
inline int RegTree::GetNext(int pid, bst_float fvalue, bool is_unknown) const {
|
||||
bst_float split_value = (*this)[pid].split_cond();
|
||||
|
||||
@ -911,7 +911,8 @@ class Booster(object):
|
||||
self._validate_features(data)
|
||||
return self.eval_set([(data, name)], iteration)
|
||||
|
||||
def predict(self, data, output_margin=False, ntree_limit=0, pred_leaf=False):
|
||||
def predict(self, data, output_margin=False, ntree_limit=0, pred_leaf=False,
|
||||
pred_contribs=False):
|
||||
"""
|
||||
Predict with data.
|
||||
|
||||
@ -937,6 +938,12 @@ class Booster(object):
|
||||
Note that the leaf index of a tree is unique per tree, so you may find leaf 1
|
||||
in both tree 1 and tree 0.
|
||||
|
||||
pred_contribs : bool
|
||||
When this option is on, the output will be a matrix of (nsample, nfeats+1)
|
||||
with each record indicating the feature contributions of all trees. The sum of
|
||||
all feature contributions is equal to the prediction. Note that the bias is added
|
||||
as the final column, on top of the regular features.
|
||||
|
||||
Returns
|
||||
-------
|
||||
prediction : numpy array
|
||||
@ -946,6 +953,8 @@ class Booster(object):
|
||||
option_mask |= 0x01
|
||||
if pred_leaf:
|
||||
option_mask |= 0x02
|
||||
if pred_contribs:
|
||||
option_mask |= 0x04
|
||||
|
||||
self._validate_features(data)
|
||||
|
||||
|
||||
@ -622,7 +622,8 @@ XGB_DLL int XGBoosterPredict(BoosterHandle handle,
|
||||
static_cast<std::shared_ptr<DMatrix>*>(dmat)->get(),
|
||||
(option_mask & 1) != 0,
|
||||
&preds, ntree_limit,
|
||||
(option_mask & 2) != 0);
|
||||
(option_mask & 2) != 0,
|
||||
(option_mask & 4) != 0);
|
||||
*out_result = dmlc::BeginPtr(preds);
|
||||
*len = static_cast<xgboost::bst_ulong>(preds.size());
|
||||
API_END();
|
||||
|
||||
@ -223,6 +223,11 @@ class GBLinear : public GradientBooster {
|
||||
unsigned ntree_limit) override {
|
||||
LOG(FATAL) << "gblinear does not support predict leaf index";
|
||||
}
|
||||
void PredictContribution(DMatrix* p_fmat,
|
||||
std::vector<bst_float>* out_contribs,
|
||||
unsigned ntree_limit) override {
|
||||
LOG(FATAL) << "gblinear does not support predict contributions";
|
||||
}
|
||||
|
||||
std::vector<std::string> DumpModel(const FeatureMap& fmap,
|
||||
bool with_stats,
|
||||
|
||||
@ -322,6 +322,14 @@ class GBTree : public GradientBooster {
|
||||
this->PredPath(p_fmat, out_preds, ntree_limit);
|
||||
}
|
||||
|
||||
void PredictContribution(DMatrix* p_fmat,
|
||||
std::vector<bst_float>* out_contribs,
|
||||
unsigned ntree_limit) override {
|
||||
const int nthread = omp_get_max_threads();
|
||||
InitThreadTemp(nthread);
|
||||
this->PredContrib(p_fmat, out_contribs, ntree_limit);
|
||||
}
|
||||
|
||||
std::vector<std::string> DumpModel(const FeatureMap& fmap,
|
||||
bool with_stats,
|
||||
std::string format) const override {
|
||||
@ -553,6 +561,62 @@ class GBTree : public GradientBooster {
|
||||
}
|
||||
}
|
||||
}
|
||||
// predict contributions
|
||||
inline void PredContrib(DMatrix *p_fmat,
|
||||
std::vector<bst_float> *out_contribs,
|
||||
unsigned ntree_limit) {
|
||||
const MetaInfo& info = p_fmat->info();
|
||||
// number of valid trees
|
||||
ntree_limit *= mparam.num_output_group;
|
||||
if (ntree_limit == 0 || ntree_limit > trees.size()) {
|
||||
ntree_limit = static_cast<unsigned>(trees.size());
|
||||
}
|
||||
size_t ncolumns = mparam.num_feature + 1;
|
||||
// allocate space for (number of features + bias) times the number of rows
|
||||
std::vector<bst_float>& contribs = *out_contribs;
|
||||
contribs.resize(info.num_row * ncolumns * mparam.num_output_group);
|
||||
// make sure contributions is zeroed, we could be reusing a previously allocated one
|
||||
std::fill(contribs.begin(), contribs.end(), 0);
|
||||
// initialize tree node mean values
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (bst_omp_uint i=0; i < ntree_limit; ++i) {
|
||||
trees[i]->FillNodeMeanValues();
|
||||
}
|
||||
// start collecting the contributions
|
||||
dmlc::DataIter<RowBatch>* iter = p_fmat->RowIterator();
|
||||
const std::vector<bst_float>& base_margin = p_fmat->info().base_margin;
|
||||
iter->BeforeFirst();
|
||||
while (iter->Next()) {
|
||||
const RowBatch& batch = iter->Value();
|
||||
// parallel over local batch
|
||||
const bst_omp_uint nsize = static_cast<bst_omp_uint>(batch.size);
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (bst_omp_uint i = 0; i < nsize; ++i) {
|
||||
size_t row_idx = static_cast<size_t>(batch.base_rowid + i);
|
||||
unsigned root_id = info.GetRoot(row_idx);
|
||||
RegTree::FVec &feats = thread_temp[omp_get_thread_num()];
|
||||
// loop over all classes
|
||||
for (int gid = 0; gid < mparam.num_output_group; ++gid) {
|
||||
bst_float *p_contribs = &contribs[(row_idx * mparam.num_output_group + gid) * ncolumns];
|
||||
feats.Fill(batch[i]);
|
||||
// calculate contributions
|
||||
for (unsigned j = 0; j < ntree_limit; ++j) {
|
||||
if (tree_info[j] != gid) {
|
||||
continue;
|
||||
}
|
||||
trees[j]->CalculateContributions(feats, root_id, p_contribs);
|
||||
}
|
||||
feats.Drop(batch[i]);
|
||||
// add base margin to BIAS feature
|
||||
if (base_margin.size() != 0) {
|
||||
p_contribs[ncolumns - 1] += base_margin[row_idx * mparam.num_output_group + gid];
|
||||
} else {
|
||||
p_contribs[ncolumns - 1] += base_margin_;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// init thread buffers
|
||||
inline void InitThreadTemp(int nthread) {
|
||||
int prev_thread_temp_size = thread_temp.size();
|
||||
|
||||
@ -400,8 +400,11 @@ class LearnerImpl : public Learner {
|
||||
bool output_margin,
|
||||
std::vector<bst_float> *out_preds,
|
||||
unsigned ntree_limit,
|
||||
bool pred_leaf) const override {
|
||||
if (pred_leaf) {
|
||||
bool pred_leaf,
|
||||
bool pred_contribs) const override {
|
||||
if (pred_contribs) {
|
||||
gbm_->PredictContribution(data, out_preds, ntree_limit);
|
||||
} else if (pred_leaf) {
|
||||
gbm_->PredictLeaf(data, out_preds, ntree_limit);
|
||||
} else {
|
||||
this->PredictRaw(data, out_preds, ntree_limit);
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
import numpy as np
|
||||
import xgboost as xgb
|
||||
import unittest
|
||||
import itertools
|
||||
import json
|
||||
|
||||
dpath = 'demo/data/'
|
||||
@ -250,3 +251,26 @@ class TestBasic(unittest.TestCase):
|
||||
cv = xgb.cv(params, dm, num_boost_round=10, shuffle=False, nfold=10, as_pandas=False)
|
||||
assert isinstance(cv, dict)
|
||||
assert len(cv) == (4)
|
||||
|
||||
|
||||
def test_contributions():
|
||||
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
|
||||
dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
|
||||
|
||||
def test_fn(max_depth, num_rounds):
|
||||
# train
|
||||
params = {'max_depth': max_depth, 'eta': 1, 'silent': 1}
|
||||
bst = xgb.train(params, dtrain, num_boost_round=num_rounds)
|
||||
|
||||
# predict
|
||||
preds = bst.predict(dtest)
|
||||
contribs = bst.predict(dtest, pred_contribs=True)
|
||||
|
||||
# result should be (number of features + BIAS) * number of rows
|
||||
assert contribs.shape == (dtest.num_row(), dtest.num_col() + 1)
|
||||
|
||||
# sum of contributions should be same as predictions
|
||||
np.testing.assert_array_almost_equal(np.sum(contribs, axis=1), preds)
|
||||
|
||||
for max_depth, num_rounds in itertools.product(range(0, 3), range(1, 5)):
|
||||
yield test_fn, max_depth, num_rounds
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user