Add prediction of feature contributions (#2003)

* Add prediction of feature contributions

This implements the idea described at http://blog.datadive.net/interpreting-random-forests/
which tries to give insight into how a prediction is composed of its feature contributions
and a bias.

* Support multi-class models

* Calculate learning_rate per-tree instead of using the one from the first tree

* Do not rely on node.base_weight * learning_rate having the same value as the node mean value (aka leaf value, if it were a leaf); instead calculate them (lazily) on-the-fly

* Add simple test for contributions feature

* Check against param.num_nodes instead of checking for non-zero length

* Loop over all roots instead of only the first
This commit is contained in:
Maurus Cuelenaere
2017-05-14 07:58:10 +02:00
committed by Vadim Khotilovich
parent e62be19c70
commit 6bd1869026
10 changed files with 205 additions and 5 deletions

View File

@@ -382,6 +382,7 @@ XGB_DLL int XGBoosterEvalOneIter(BoosterHandle handle,
* 0:normal prediction
* 1:output margin instead of transformed value
* 2:output leaf index of trees instead of leaf value, note leaf index is unique per tree
* 4:output feature contributions of all trees instead of predictions
* \param ntree_limit limit number of trees used for prediction, this is only valid for boosted trees
* when the parameter is set to 0, we will use all the trees
* \param out_len used to store length of returning result

View File

@@ -107,6 +107,19 @@ class GradientBooster {
virtual void PredictLeaf(DMatrix* dmat,
std::vector<bst_float>* out_preds,
unsigned ntree_limit = 0) = 0;
/*!
 * \brief predict the feature contributions of each tree, the output will be an nsample * (nfeats + 1) vector
* this is only valid in gbtree predictor
* \param dmat feature matrix
* \param out_contribs output vector to hold the contributions
* \param ntree_limit limit the number of trees used in prediction, when it equals 0, this means
* we do not limit number of trees
*/
virtual void PredictContribution(DMatrix* dmat,
std::vector<bst_float>* out_contribs,
unsigned ntree_limit = 0) = 0;
/*!
* \brief dump the model in the requested format
* \param fmap feature map that may help give interpretations of feature

View File

@@ -103,12 +103,14 @@ class Learner : public rabit::Serializable {
* \param ntree_limit limit number of trees used for boosted tree
* predictor, when it equals 0, this means we are using all the trees
* \param pred_leaf whether to only predict the leaf index of each tree in a boosted tree predictor
* \param pred_contribs whether to only predict the feature contributions of all trees
*/
virtual void Predict(DMatrix* data,
bool output_margin,
std::vector<bst_float> *out_preds,
unsigned ntree_limit = 0,
bool pred_leaf = false) const = 0;
bool pred_leaf = false,
bool pred_contribs = false) const = 0;
/*!
* \brief Set additional attribute to the Booster.
* The property will be saved along the booster.

View File

@@ -434,6 +434,11 @@ class RegTree: public TreeModel<bst_float, RTreeNodeStat> {
* \param inst The sparse instance to drop.
*/
inline void Drop(const RowBatch::Inst& inst);
/*!
* \brief returns the size of the feature vector
* \return the size of the feature vector
*/
inline size_t size() const;
/*!
* \brief get ith value
* \param i feature index.
@@ -472,6 +477,14 @@ class RegTree: public TreeModel<bst_float, RTreeNodeStat> {
* \return the leaf index of the given feature
*/
inline bst_float Predict(const FVec& feat, unsigned root_id = 0) const;
/*!
* \brief calculate the feature contributions for the given root
* \param feat dense feature vector, if the feature is missing the field is set to NaN
* \param root_id starting root index of the instance
* \param out_contribs output vector to hold the contributions
*/
inline void CalculateContributions(const RegTree::FVec& feat, unsigned root_id,
bst_float *out_contribs) const;
/*!
* \brief get next position of the tree given current pid
* \param pid Current node id.
@@ -489,6 +502,15 @@ class RegTree: public TreeModel<bst_float, RTreeNodeStat> {
std::string DumpModel(const FeatureMap& fmap,
bool with_stats,
std::string format) const;
/*!
* \brief calculate the mean value for each node, required for feature contributions
*/
inline void FillNodeMeanValues();
private:
inline bst_float FillNodeMeanValue(int nid);
std::vector<bst_float> node_mean_values;
};
// implementations of inline functions
@@ -513,6 +535,10 @@ inline void RegTree::FVec::Drop(const RowBatch::Inst& inst) {
}
}
inline size_t RegTree::FVec::size() const {
return data.size();
}
// Returns the value stored for feature i; only meaningful when the
// feature is present (callers should consult is_missing(i) first).
inline bst_float RegTree::FVec::fvalue(size_t i) const {
  return this->data[i].fvalue;
}
@@ -535,6 +561,58 @@ inline bst_float RegTree::Predict(const RegTree::FVec& feat, unsigned root_id) c
return (*this)[pid].leaf_value();
}
// Lazily (re)builds the per-node mean-value cache used by
// CalculateContributions(). A no-op when the cache already matches the
// current node count, so repeated calls are cheap.
inline void RegTree::FillNodeMeanValues() {
  const size_t n_nodes = this->param.num_nodes;
  if (this->node_mean_values.size() != n_nodes) {
    this->node_mean_values.resize(n_nodes);
    // A tree may have several roots; fill the subtree under each one.
    for (int rid = 0; rid < param.num_roots; ++rid) {
      this->FillNodeMeanValue(rid);
    }
  }
}
// Recursively computes and caches the mean prediction of the subtree
// rooted at nid. Leaves contribute their output value directly; internal
// nodes take the hessian-weighted average of their two children, which
// approximates the expected value over the training distribution.
inline bst_float RegTree::FillNodeMeanValue(int nid) {
  const auto& node = (*this)[nid];
  bst_float mean;
  if (node.is_leaf()) {
    mean = node.leaf_value();
  } else {
    const int lc = node.cleft();
    const int rc = node.cright();
    mean = this->FillNodeMeanValue(lc) * this->stat(lc).sum_hess
         + this->FillNodeMeanValue(rc) * this->stat(rc).sum_hess;
    mean /= this->stat(nid).sum_hess;
  }
  this->node_mean_values[nid] = mean;
  return mean;
}
// Decomposes this tree's prediction for one instance into per-feature
// contributions plus a bias term, following the idea from
// http://blog.datadive.net/interpreting-random-forests/ — walk the
// decision path and credit the change in expected value at each split
// to the feature that was split on.
// out_contribs must have nfeats + 1 slots; slot feat.size() is the bias.
// Requires FillNodeMeanValues() to have been called beforehand.
inline void RegTree::CalculateContributions(const RegTree::FVec& feat, unsigned root_id,
                                            bst_float *out_contribs) const {
  CHECK_GT(this->node_mean_values.size(), 0);
  int nid = static_cast<int>(root_id);
  bst_float cur_mean = this->node_mean_values[nid];
  // bias: the expected prediction at the root of this subtree
  out_contribs[feat.size()] += cur_mean;
  if ((*this)[nid].is_leaf()) {
    return;  // single-node tree: everything is bias
  }
  unsigned last_split = 0;  // always assigned before use (root is not a leaf)
  while (!(*this)[nid].is_leaf()) {
    last_split = (*this)[nid].split_index();
    nid = this->GetNext(nid, feat.fvalue(last_split), feat.is_missing(last_split));
    const bst_float child_mean = this->node_mean_values[nid];
    // the split feature gets credit for the shift in expectation
    out_contribs[last_split] += child_mean - cur_mean;
    cur_mean = child_mean;
  }
  // final correction from the leaf's cached mean to its exact output
  out_contribs[last_split] += (*this)[nid].leaf_value() - cur_mean;
}
/*! \brief get next position of the tree given current pid */
inline int RegTree::GetNext(int pid, bst_float fvalue, bool is_unknown) const {
bst_float split_value = (*this)[pid].split_cond();