SHAP values for feature contributions (#2438)

* SHAP values for feature contributions

* Fix commenting error

* New polynomial time SHAP value estimation algorithm

* Update API to support SHAP values

* Fix merge conflicts with updates in master

* Correct submodule hashes

* Fix variable sized stack allocation

* Make lint happy

* Add docs

* Fix typo

* Adjust tolerances

* Remove unneeded def

* Fixed cpp test setup

* Updated R API and cleaned up

* Fixed test typo
This commit is contained in:
Scott Lundberg
2017-10-12 12:35:51 -07:00
committed by GitHub
parent ff9180cd73
commit 78c4188cec
16 changed files with 369 additions and 143 deletions

View File

@@ -115,10 +115,11 @@ class GradientBooster {
* \param out_contribs output vector to hold the contributions
* \param ntree_limit limit the number of trees used in prediction, when it equals 0, this means
* we do not limit number of trees
* \param approximate use a faster (inconsistent) approximation of SHAP values
*/
virtual void PredictContribution(DMatrix* dmat,
std::vector<bst_float>* out_contribs,
unsigned ntree_limit = 0) = 0;
unsigned ntree_limit = 0, bool approximate = false) = 0;
/*!
* \brief dump the model in the requested format

View File

@@ -104,13 +104,15 @@ class Learner : public rabit::Serializable {
* predictor, when it equals 0, this means we are using all the trees
* \param pred_leaf whether to only predict the leaf index of each tree in a boosted tree predictor
* \param pred_contribs whether to only predict the feature contributions
* \param approx_contribs whether to approximate the feature contributions for speed
*/
virtual void Predict(DMatrix* data,
bool output_margin,
std::vector<bst_float> *out_preds,
unsigned ntree_limit = 0,
bool pred_leaf = false,
bool pred_contribs = false) const = 0;
bool pred_contribs = false,
bool approx_contribs = false) const = 0;
/*!
* \brief Set additional attribute to the Booster.
* The property will be saved along the booster.

View File

@@ -144,12 +144,14 @@ class Predictor {
* \param [in,out] out_contribs The output feature contribs.
* \param model Model to make predictions from.
* \param ntree_limit (Optional) The ntree limit.
* \param approximate Use fast approximate algorithm.
*/
virtual void PredictContribution(DMatrix* dmat,
std::vector<bst_float>* out_contribs,
const gbm::GBTreeModel& model,
unsigned ntree_limit = 0) = 0;
unsigned ntree_limit = 0,
bool approximate = false) = 0;
/**
* \fn static Predictor* Predictor::Create(std::string name);

View File

@@ -14,6 +14,7 @@
#include <string>
#include <cstring>
#include <algorithm>
#include <tuple>
#include "./base.h"
#include "./data.h"
#include "./logging.h"
@@ -411,6 +412,20 @@ struct RTreeNodeStat {
int leaf_child_cnt;
};
// Used by TreeShap
// data we keep about our decision path
// note that pweight is included for convenience and is not tied with the other attributes
// the pweight of the i'th path element is the permuation weight of paths with i-1 ones in them
struct PathElement {
int feature_index;
bst_float zero_fraction;
bst_float one_fraction;
bst_float pweight;
PathElement() {}
PathElement(int i, bst_float z, bst_float o, bst_float w) :
feature_index(i), zero_fraction(z), one_fraction(o), pweight(w) {}
};
/*!
* \brief define regression tree to be the most common tree model.
* This is the data structure used in xgboost's major tree models.
@@ -482,13 +497,26 @@ class RegTree: public TreeModel<bst_float, RTreeNodeStat> {
*/
inline bst_float Predict(const FVec& feat, unsigned root_id = 0) const;
/*!
* \brief calculate the feature contributions for the given root
* \brief calculate the feature contributions (https://arxiv.org/abs/1706.06060) for the tree
* \param feat dense feature vector, if the feature is missing the field is set to NaN
* \param root_id starting root index of the instance
* \param out_contribs output vector to hold the contributions
*/
inline void CalculateContributions(const RegTree::FVec& feat, unsigned root_id,
bst_float *out_contribs) const;
inline void TreeShap(const RegTree::FVec& feat, bst_float *phi,
unsigned node_index, unsigned unique_depth,
PathElement *parent_unique_path, bst_float parent_zero_fraction,
bst_float parent_one_fraction, int parent_feature_index) const;
/*!
* \brief calculate the approximate feature contributions for the given root
* \param feat dense feature vector, if the feature is missing the field is set to NaN
* \param root_id starting root index of the instance
* \param out_contribs output vector to hold the contributions
*/
inline void CalculateContributionsApprox(const RegTree::FVec& feat, unsigned root_id,
bst_float *out_contribs) const;
/*!
* \brief get next position of the tree given current pid
* \param pid Current node id.
@@ -590,7 +618,7 @@ inline bst_float RegTree::FillNodeMeanValue(int nid) {
return result;
}
inline void RegTree::CalculateContributions(const RegTree::FVec& feat, unsigned root_id,
inline void RegTree::CalculateContributionsApprox(const RegTree::FVec& feat, unsigned root_id,
bst_float *out_contribs) const {
CHECK_GT(this->node_mean_values.size(), 0U);
// this follows the idea of http://blog.datadive.net/interpreting-random-forests/
@@ -617,6 +645,154 @@ inline void RegTree::CalculateContributions(const RegTree::FVec& feat, unsigned
out_contribs[split_index] += leaf_value - node_value;
}
// extend our decision path with a fraction of one and zero extensions
inline void ExtendPath(PathElement *unique_path, unsigned unique_depth,
bst_float zero_fraction, bst_float one_fraction, int feature_index) {
unique_path[unique_depth].feature_index = feature_index;
unique_path[unique_depth].zero_fraction = zero_fraction;
unique_path[unique_depth].one_fraction = one_fraction;
unique_path[unique_depth].pweight = (unique_depth == 0 ? 1 : 0);
for (int i = unique_depth-1; i >= 0; i--) {
unique_path[i+1].pweight += one_fraction*unique_path[i].pweight*(i+1)
/ static_cast<bst_float>(unique_depth+1);
unique_path[i].pweight = zero_fraction*unique_path[i].pweight*(unique_depth-i)
/ static_cast<bst_float>(unique_depth+1);
}
}
// undo a previous extension of the decision path
inline void UnwindPath(PathElement *unique_path, unsigned unique_depth, unsigned path_index) {
const bst_float one_fraction = unique_path[path_index].one_fraction;
const bst_float zero_fraction = unique_path[path_index].zero_fraction;
bst_float next_one_portion = unique_path[unique_depth].pweight;
for (int i = unique_depth-1; i >= 0; --i) {
if (one_fraction != 0) {
const bst_float tmp = unique_path[i].pweight;
unique_path[i].pweight = next_one_portion*(unique_depth+1)
/ static_cast<bst_float>((i+1)*one_fraction);
next_one_portion = tmp - unique_path[i].pweight*zero_fraction*(unique_depth-i)
/ static_cast<bst_float>(unique_depth+1);
} else {
unique_path[i].pweight = (unique_path[i].pweight*(unique_depth+1))
/ static_cast<bst_float>(zero_fraction*(unique_depth-i));
}
}
for (int i = path_index; i < unique_depth; ++i) {
unique_path[i].feature_index = unique_path[i+1].feature_index;
unique_path[i].zero_fraction = unique_path[i+1].zero_fraction;
unique_path[i].one_fraction = unique_path[i+1].one_fraction;
}
}
// determine what the total permuation weight would be if
// we unwound a previous extension in the decision path
inline bst_float UnwoundPathSum(const PathElement *unique_path, unsigned unique_depth,
unsigned path_index) {
const bst_float one_fraction = unique_path[path_index].one_fraction;
const bst_float zero_fraction = unique_path[path_index].zero_fraction;
bst_float next_one_portion = unique_path[unique_depth].pweight;
bst_float total = 0;
for (int i = unique_depth-1; i >= 0; --i) {
if (one_fraction != 0) {
const bst_float tmp = next_one_portion*(unique_depth+1)
/ static_cast<bst_float>((i+1)*one_fraction);
total += tmp;
next_one_portion = unique_path[i].pweight - tmp*zero_fraction*((unique_depth-i)
/ static_cast<bst_float>(unique_depth+1));
} else {
total += (unique_path[i].pweight/zero_fraction)/((unique_depth-i)
/ static_cast<bst_float>(unique_depth+1));
}
}
return total;
}
// recursive computation of SHAP values for a decision tree
inline void RegTree::TreeShap(const RegTree::FVec& feat, bst_float *phi,
unsigned node_index, unsigned unique_depth,
PathElement *parent_unique_path, bst_float parent_zero_fraction,
bst_float parent_one_fraction, int parent_feature_index) const {
const auto node = (*this)[node_index];
// extend the unique path
PathElement *unique_path = parent_unique_path + unique_depth;
if (unique_depth > 0) std::copy(parent_unique_path, parent_unique_path+unique_depth, unique_path);
ExtendPath(unique_path, unique_depth, parent_zero_fraction,
parent_one_fraction, parent_feature_index);
const unsigned split_index = node.split_index();
// leaf node
if (node.is_leaf()) {
for (int i = 1; i <= unique_depth; ++i) {
const bst_float w = UnwoundPathSum(unique_path, unique_depth, i);
const PathElement &el = unique_path[i];
phi[el.feature_index] += w*(el.one_fraction-el.zero_fraction)*node.leaf_value();
}
// internal node
} else {
// find which branch is "hot" (meaning x would follow it)
unsigned hot_index = 0;
if (feat.is_missing(split_index)) {
hot_index = node.cdefault();
} else if (feat.fvalue(split_index) < node.split_cond()) {
hot_index = node.cleft();
} else {
hot_index = node.cright();
}
const unsigned cold_index = (hot_index == node.cleft() ? node.cright() : node.cleft());
const bst_float w = this->stat(node_index).sum_hess;
const bst_float hot_zero_fraction = this->stat(hot_index).sum_hess/w;
const bst_float cold_zero_fraction = this->stat(cold_index).sum_hess/w;
bst_float incoming_zero_fraction = 1;
bst_float incoming_one_fraction = 1;
// see if we have already split on this feature,
// if so we undo that split so we can redo it for this node
unsigned path_index = 0;
for (; path_index <= unique_depth; ++path_index) {
if (unique_path[path_index].feature_index == split_index) break;
}
if (path_index != unique_depth+1) {
incoming_zero_fraction = unique_path[path_index].zero_fraction;
incoming_one_fraction = unique_path[path_index].one_fraction;
UnwindPath(unique_path, unique_depth, path_index);
unique_depth -= 1;
}
TreeShap(feat, phi, hot_index, unique_depth+1, unique_path,
hot_zero_fraction*incoming_zero_fraction, incoming_one_fraction, split_index);
TreeShap(feat, phi, cold_index, unique_depth+1, unique_path,
cold_zero_fraction*incoming_zero_fraction, 0, split_index);
}
}
inline void RegTree::CalculateContributions(const RegTree::FVec& feat, unsigned root_id,
bst_float *out_contribs) const {
// find the expected value of the tree's predictions
bst_float base_value = 0.0;
bst_float total_cover = 0;
for (unsigned i = 0; i < (*this).param.num_nodes; ++i) {
const auto node = (*this)[i];
if (node.is_leaf()) {
const auto cover = this->stat(i).sum_hess;
base_value += cover*node.leaf_value();
total_cover += cover;
}
}
out_contribs[feat.size()] += base_value / total_cover;
// Preallocate space for the unique path data
const int maxd = this->MaxDepth(root_id)+1;
PathElement *unique_path_data = new PathElement[(maxd*(maxd+1))/2];
TreeShap(feat, out_contribs, root_id, 0, unique_path_data, 1, 1, -1);
delete[] unique_path_data;
}
/*! \brief get next position of the tree given current pid */
inline int RegTree::GetNext(int pid, bst_float fvalue, bool is_unknown) const {
bst_float split_value = (*this)[pid].split_cond();