Split up SHAP from RegTree. (#8612)
* Split up SHAP from `RegTree`. Simplify the tree interface.
This commit is contained in:
parent
d308124910
commit
beefd28471
@ -53,6 +53,7 @@ OBJECTS= \
|
|||||||
$(PKGROOT)/src/data/iterative_dmatrix.o \
|
$(PKGROOT)/src/data/iterative_dmatrix.o \
|
||||||
$(PKGROOT)/src/predictor/predictor.o \
|
$(PKGROOT)/src/predictor/predictor.o \
|
||||||
$(PKGROOT)/src/predictor/cpu_predictor.o \
|
$(PKGROOT)/src/predictor/cpu_predictor.o \
|
||||||
|
$(PKGROOT)/src/predictor/cpu_treeshap.o \
|
||||||
$(PKGROOT)/src/tree/constraints.o \
|
$(PKGROOT)/src/tree/constraints.o \
|
||||||
$(PKGROOT)/src/tree/param.o \
|
$(PKGROOT)/src/tree/param.o \
|
||||||
$(PKGROOT)/src/tree/fit_stump.o \
|
$(PKGROOT)/src/tree/fit_stump.o \
|
||||||
|
|||||||
@ -53,6 +53,7 @@ OBJECTS= \
|
|||||||
$(PKGROOT)/src/data/iterative_dmatrix.o \
|
$(PKGROOT)/src/data/iterative_dmatrix.o \
|
||||||
$(PKGROOT)/src/predictor/predictor.o \
|
$(PKGROOT)/src/predictor/predictor.o \
|
||||||
$(PKGROOT)/src/predictor/cpu_predictor.o \
|
$(PKGROOT)/src/predictor/cpu_predictor.o \
|
||||||
|
$(PKGROOT)/src/predictor/cpu_treeshap.o \
|
||||||
$(PKGROOT)/src/tree/constraints.o \
|
$(PKGROOT)/src/tree/constraints.o \
|
||||||
$(PKGROOT)/src/tree/param.o \
|
$(PKGROOT)/src/tree/param.o \
|
||||||
$(PKGROOT)/src/tree/fit_stump.o \
|
$(PKGROOT)/src/tree/fit_stump.o \
|
||||||
|
|||||||
@ -542,37 +542,6 @@ class RegTree : public Model {
|
|||||||
bool has_missing_;
|
bool has_missing_;
|
||||||
};
|
};
|
||||||
|
|
||||||
/*!
|
|
||||||
* \brief calculate the feature contributions (https://arxiv.org/abs/1706.06060) for the tree
|
|
||||||
* \param feat dense feature vector, if the feature is missing the field is set to NaN
|
|
||||||
* \param out_contribs output vector to hold the contributions
|
|
||||||
* \param condition fix one feature to either off (-1) on (1) or not fixed (0 default)
|
|
||||||
* \param condition_feature the index of the feature to fix
|
|
||||||
*/
|
|
||||||
void CalculateContributions(const RegTree::FVec& feat,
|
|
||||||
std::vector<float>* mean_values,
|
|
||||||
bst_float* out_contribs, int condition = 0,
|
|
||||||
unsigned condition_feature = 0) const;
|
|
||||||
/*!
|
|
||||||
* \brief Recursive function that computes the feature attributions for a single tree.
|
|
||||||
* \param feat dense feature vector, if the feature is missing the field is set to NaN
|
|
||||||
* \param phi dense output vector of feature attributions
|
|
||||||
* \param node_index the index of the current node in the tree
|
|
||||||
* \param unique_depth how many unique features are above the current node in the tree
|
|
||||||
* \param parent_unique_path a vector of statistics about our current path through the tree
|
|
||||||
* \param parent_zero_fraction what fraction of the parent path weight is coming as 0 (integrated)
|
|
||||||
* \param parent_one_fraction what fraction of the parent path weight is coming as 1 (fixed)
|
|
||||||
* \param parent_feature_index what feature the parent node used to split
|
|
||||||
* \param condition fix one feature to either off (-1) on (1) or not fixed (0 default)
|
|
||||||
* \param condition_feature the index of the feature to fix
|
|
||||||
* \param condition_fraction what fraction of the current weight matches our conditioning feature
|
|
||||||
*/
|
|
||||||
void TreeShap(const RegTree::FVec& feat, bst_float* phi, bst_node_t node_index,
|
|
||||||
unsigned unique_depth, PathElement* parent_unique_path,
|
|
||||||
bst_float parent_zero_fraction, bst_float parent_one_fraction,
|
|
||||||
int parent_feature_index, int condition,
|
|
||||||
unsigned condition_feature, bst_float condition_fraction) const;
|
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief calculate the approximate feature contributions for the given root
|
* \brief calculate the approximate feature contributions for the given root
|
||||||
* \param feat dense feature vector, if the feature is missing the field is set to NaN
|
* \param feat dense feature vector, if the feature is missing the field is set to NaN
|
||||||
|
|||||||
@ -15,6 +15,7 @@
|
|||||||
#include "../data/gradient_index.h"
|
#include "../data/gradient_index.h"
|
||||||
#include "../data/proxy_dmatrix.h"
|
#include "../data/proxy_dmatrix.h"
|
||||||
#include "../gbm/gbtree_model.h"
|
#include "../gbm/gbtree_model.h"
|
||||||
|
#include "cpu_treeshap.h" // CalculateContributions
|
||||||
#include "predict_fn.h"
|
#include "predict_fn.h"
|
||||||
#include "xgboost/base.h"
|
#include "xgboost/base.h"
|
||||||
#include "xgboost/data.h"
|
#include "xgboost/data.h"
|
||||||
@ -530,9 +531,8 @@ class CPUPredictor : public Predictor {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if (!approximate) {
|
if (!approximate) {
|
||||||
model.trees[j]->CalculateContributions(
|
CalculateContributions(*model.trees[j], feats, tree_mean_values,
|
||||||
feats, tree_mean_values, &this_tree_contribs[0], condition,
|
&this_tree_contribs[0], condition, condition_feature);
|
||||||
condition_feature);
|
|
||||||
} else {
|
} else {
|
||||||
model.trees[j]->CalculateContributionsApprox(
|
model.trees[j]->CalculateContributionsApprox(
|
||||||
feats, tree_mean_values, &this_tree_contribs[0]);
|
feats, tree_mean_values, &this_tree_contribs[0]);
|
||||||
|
|||||||
202
src/predictor/cpu_treeshap.cc
Normal file
202
src/predictor/cpu_treeshap.cc
Normal file
@ -0,0 +1,202 @@
|
|||||||
|
/**
|
||||||
|
* Copyright by XGBoost Contributors 2017-2022
|
||||||
|
*/
|
||||||
|
#include "cpu_treeshap.h"
|
||||||
|
|
||||||
|
#include <cinttypes> // std::uint32_t
|
||||||
|
|
||||||
|
#include "predict_fn.h" // GetNextNode
|
||||||
|
#include "xgboost/base.h" // bst_node_t
|
||||||
|
#include "xgboost/logging.h"
|
||||||
|
#include "xgboost/tree_model.h" // RegTree
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
// Used by TreeShap
|
||||||
|
// data we keep about our decision path
|
||||||
|
// note that pweight is included for convenience and is not tied with the other attributes
|
||||||
|
// the pweight of the i'th path element is the permutation weight of paths with i-1 ones in them
|
||||||
|
struct PathElement {
|
||||||
|
int feature_index;
|
||||||
|
float zero_fraction;
|
||||||
|
float one_fraction;
|
||||||
|
float pweight;
|
||||||
|
PathElement() = default;
|
||||||
|
PathElement(int i, float z, float o, float w)
|
||||||
|
: feature_index(i), zero_fraction(z), one_fraction(o), pweight(w) {}
|
||||||
|
};
|
||||||
|
|
||||||
|
// extend our decision path with a fraction of one and zero extensions
|
||||||
|
void ExtendPath(PathElement* unique_path, std::uint32_t unique_depth, float zero_fraction,
|
||||||
|
float one_fraction, int feature_index) {
|
||||||
|
unique_path[unique_depth].feature_index = feature_index;
|
||||||
|
unique_path[unique_depth].zero_fraction = zero_fraction;
|
||||||
|
unique_path[unique_depth].one_fraction = one_fraction;
|
||||||
|
unique_path[unique_depth].pweight = (unique_depth == 0 ? 1.0f : 0.0f);
|
||||||
|
for (int i = unique_depth - 1; i >= 0; i--) {
|
||||||
|
unique_path[i + 1].pweight +=
|
||||||
|
one_fraction * unique_path[i].pweight * (i + 1) / static_cast<float>(unique_depth + 1);
|
||||||
|
unique_path[i].pweight = zero_fraction * unique_path[i].pweight * (unique_depth - i) /
|
||||||
|
static_cast<float>(unique_depth + 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// undo a previous extension of the decision path
|
||||||
|
void UnwindPath(PathElement* unique_path, std::uint32_t unique_depth, std::uint32_t path_index) {
|
||||||
|
const float one_fraction = unique_path[path_index].one_fraction;
|
||||||
|
const float zero_fraction = unique_path[path_index].zero_fraction;
|
||||||
|
float next_one_portion = unique_path[unique_depth].pweight;
|
||||||
|
|
||||||
|
for (int i = unique_depth - 1; i >= 0; --i) {
|
||||||
|
if (one_fraction != 0) {
|
||||||
|
const float tmp = unique_path[i].pweight;
|
||||||
|
unique_path[i].pweight =
|
||||||
|
next_one_portion * (unique_depth + 1) / static_cast<float>((i + 1) * one_fraction);
|
||||||
|
next_one_portion = tmp - unique_path[i].pweight * zero_fraction * (unique_depth - i) /
|
||||||
|
static_cast<float>(unique_depth + 1);
|
||||||
|
} else {
|
||||||
|
unique_path[i].pweight = (unique_path[i].pweight * (unique_depth + 1)) /
|
||||||
|
static_cast<float>(zero_fraction * (unique_depth - i));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto i = path_index; i < unique_depth; ++i) {
|
||||||
|
unique_path[i].feature_index = unique_path[i + 1].feature_index;
|
||||||
|
unique_path[i].zero_fraction = unique_path[i + 1].zero_fraction;
|
||||||
|
unique_path[i].one_fraction = unique_path[i + 1].one_fraction;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// determine what the total permutation weight would be if
|
||||||
|
// we unwound a previous extension in the decision path
|
||||||
|
float UnwoundPathSum(const PathElement* unique_path, std::uint32_t unique_depth,
|
||||||
|
std::uint32_t path_index) {
|
||||||
|
const float one_fraction = unique_path[path_index].one_fraction;
|
||||||
|
const float zero_fraction = unique_path[path_index].zero_fraction;
|
||||||
|
float next_one_portion = unique_path[unique_depth].pweight;
|
||||||
|
float total = 0;
|
||||||
|
for (int i = unique_depth - 1; i >= 0; --i) {
|
||||||
|
if (one_fraction != 0) {
|
||||||
|
const float tmp =
|
||||||
|
next_one_portion * (unique_depth + 1) / static_cast<float>((i + 1) * one_fraction);
|
||||||
|
total += tmp;
|
||||||
|
next_one_portion =
|
||||||
|
unique_path[i].pweight -
|
||||||
|
tmp * zero_fraction * ((unique_depth - i) / static_cast<float>(unique_depth + 1));
|
||||||
|
} else if (zero_fraction != 0) {
|
||||||
|
total += (unique_path[i].pweight / zero_fraction) /
|
||||||
|
((unique_depth - i) / static_cast<float>(unique_depth + 1));
|
||||||
|
} else {
|
||||||
|
CHECK_EQ(unique_path[i].pweight, 0) << "Unique path " << i << " must have zero weight";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return total;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* \brief Recursive function that computes the feature attributions for a single tree.
|
||||||
|
* \param feat dense feature vector, if the feature is missing the field is set to NaN
|
||||||
|
* \param phi dense output vector of feature attributions
|
||||||
|
* \param node_index the index of the current node in the tree
|
||||||
|
* \param unique_depth how many unique features are above the current node in the tree
|
||||||
|
* \param parent_unique_path a vector of statistics about our current path through the tree
|
||||||
|
* \param parent_zero_fraction what fraction of the parent path weight is coming as 0 (integrated)
|
||||||
|
* \param parent_one_fraction what fraction of the parent path weight is coming as 1 (fixed)
|
||||||
|
* \param parent_feature_index what feature the parent node used to split
|
||||||
|
* \param condition fix one feature to either off (-1) on (1) or not fixed (0 default)
|
||||||
|
* \param condition_feature the index of the feature to fix
|
||||||
|
* \param condition_fraction what fraction of the current weight matches our conditioning feature
|
||||||
|
*/
|
||||||
|
void TreeShap(RegTree const& tree, const RegTree::FVec& feat, float* phi, bst_node_t node_index,
|
||||||
|
std::uint32_t unique_depth, PathElement* parent_unique_path,
|
||||||
|
float parent_zero_fraction, float parent_one_fraction, int parent_feature_index,
|
||||||
|
int condition, std::uint32_t condition_feature, float condition_fraction) {
|
||||||
|
const auto node = tree[node_index];
|
||||||
|
|
||||||
|
// stop if we have no weight coming down to us
|
||||||
|
if (condition_fraction == 0) return;
|
||||||
|
|
||||||
|
// extend the unique path
|
||||||
|
PathElement* unique_path = parent_unique_path + unique_depth + 1;
|
||||||
|
std::copy(parent_unique_path, parent_unique_path + unique_depth + 1, unique_path);
|
||||||
|
|
||||||
|
if (condition == 0 || condition_feature != static_cast<std::uint32_t>(parent_feature_index)) {
|
||||||
|
ExtendPath(unique_path, unique_depth, parent_zero_fraction, parent_one_fraction,
|
||||||
|
parent_feature_index);
|
||||||
|
}
|
||||||
|
const std::uint32_t split_index = node.SplitIndex();
|
||||||
|
|
||||||
|
// leaf node
|
||||||
|
if (node.IsLeaf()) {
|
||||||
|
for (std::uint32_t i = 1; i <= unique_depth; ++i) {
|
||||||
|
const float w = UnwoundPathSum(unique_path, unique_depth, i);
|
||||||
|
const PathElement& el = unique_path[i];
|
||||||
|
phi[el.feature_index] +=
|
||||||
|
w * (el.one_fraction - el.zero_fraction) * node.LeafValue() * condition_fraction;
|
||||||
|
}
|
||||||
|
|
||||||
|
// internal node
|
||||||
|
} else {
|
||||||
|
// find which branch is "hot" (meaning x would follow it)
|
||||||
|
auto const& cats = tree.GetCategoriesMatrix();
|
||||||
|
bst_node_t hot_index = predictor::GetNextNode<true, true>(
|
||||||
|
node, node_index, feat.GetFvalue(split_index), feat.IsMissing(split_index), cats);
|
||||||
|
|
||||||
|
const auto cold_index = (hot_index == node.LeftChild() ? node.RightChild() : node.LeftChild());
|
||||||
|
const float w = tree.Stat(node_index).sum_hess;
|
||||||
|
const float hot_zero_fraction = tree.Stat(hot_index).sum_hess / w;
|
||||||
|
const float cold_zero_fraction = tree.Stat(cold_index).sum_hess / w;
|
||||||
|
float incoming_zero_fraction = 1;
|
||||||
|
float incoming_one_fraction = 1;
|
||||||
|
|
||||||
|
// see if we have already split on this feature,
|
||||||
|
// if so we undo that split so we can redo it for this node
|
||||||
|
std::uint32_t path_index = 0;
|
||||||
|
for (; path_index <= unique_depth; ++path_index) {
|
||||||
|
if (static_cast<std::uint32_t>(unique_path[path_index].feature_index) == split_index) break;
|
||||||
|
}
|
||||||
|
if (path_index != unique_depth + 1) {
|
||||||
|
incoming_zero_fraction = unique_path[path_index].zero_fraction;
|
||||||
|
incoming_one_fraction = unique_path[path_index].one_fraction;
|
||||||
|
UnwindPath(unique_path, unique_depth, path_index);
|
||||||
|
unique_depth -= 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// divide up the condition_fraction among the recursive calls
|
||||||
|
float hot_condition_fraction = condition_fraction;
|
||||||
|
float cold_condition_fraction = condition_fraction;
|
||||||
|
if (condition > 0 && split_index == condition_feature) {
|
||||||
|
cold_condition_fraction = 0;
|
||||||
|
unique_depth -= 1;
|
||||||
|
} else if (condition < 0 && split_index == condition_feature) {
|
||||||
|
hot_condition_fraction *= hot_zero_fraction;
|
||||||
|
cold_condition_fraction *= cold_zero_fraction;
|
||||||
|
unique_depth -= 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
TreeShap(tree, feat, phi, hot_index, unique_depth + 1, unique_path,
|
||||||
|
hot_zero_fraction * incoming_zero_fraction, incoming_one_fraction, split_index,
|
||||||
|
condition, condition_feature, hot_condition_fraction);
|
||||||
|
|
||||||
|
TreeShap(tree, feat, phi, cold_index, unique_depth + 1, unique_path,
|
||||||
|
cold_zero_fraction * incoming_zero_fraction, 0, split_index, condition,
|
||||||
|
condition_feature, cold_condition_fraction);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void CalculateContributions(RegTree const& tree, const RegTree::FVec& feat,
|
||||||
|
std::vector<float>* mean_values, float* out_contribs, int condition,
|
||||||
|
std::uint32_t condition_feature) {
|
||||||
|
// find the expected value of the tree's predictions
|
||||||
|
if (condition == 0) {
|
||||||
|
float node_value = (*mean_values)[0];
|
||||||
|
out_contribs[feat.Size()] += node_value;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Preallocate space for the unique path data
|
||||||
|
const int maxd = tree.MaxDepth(0) + 2;
|
||||||
|
std::vector<PathElement> unique_path_data((maxd * (maxd + 1)) / 2);
|
||||||
|
|
||||||
|
TreeShap(tree, feat, out_contribs, 0, 0, unique_path_data.data(), 1, 1, -1, condition,
|
||||||
|
condition_feature, 1);
|
||||||
|
}
|
||||||
|
} // namespace xgboost
|
||||||
17
src/predictor/cpu_treeshap.h
Normal file
17
src/predictor/cpu_treeshap.h
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
/**
|
||||||
|
* Copyright by XGBoost Contributors 2017-2022
|
||||||
|
*/
|
||||||
|
#include "xgboost/tree_model.h" // RegTree
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
/**
|
||||||
|
* \brief calculate the feature contributions (https://arxiv.org/abs/1706.06060) for the tree
|
||||||
|
* \param feat dense feature vector, if the feature is missing the field is set to NaN
|
||||||
|
* \param out_contribs output vector to hold the contributions
|
||||||
|
* \param condition fix one feature to either off (-1) on (1) or not fixed (0 default)
|
||||||
|
* \param condition_feature the index of the feature to fix
|
||||||
|
*/
|
||||||
|
void CalculateContributions(RegTree const &tree, const RegTree::FVec &feat,
|
||||||
|
std::vector<float> *mean_values, bst_float *out_contribs, int condition,
|
||||||
|
unsigned condition_feature);
|
||||||
|
} // namespace xgboost
|
||||||
@ -1248,189 +1248,4 @@ void RegTree::CalculateContributionsApprox(const RegTree::FVec &feat,
|
|||||||
// update leaf feature weight
|
// update leaf feature weight
|
||||||
out_contribs[split_index] += leaf_value - node_value;
|
out_contribs[split_index] += leaf_value - node_value;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Used by TreeShap
|
|
||||||
// data we keep about our decision path
|
|
||||||
// note that pweight is included for convenience and is not tied with the other attributes
|
|
||||||
// the pweight of the i'th path element is the permutation weight of paths with i-1 ones in them
|
|
||||||
struct PathElement {
|
|
||||||
int feature_index;
|
|
||||||
bst_float zero_fraction;
|
|
||||||
bst_float one_fraction;
|
|
||||||
bst_float pweight;
|
|
||||||
PathElement() = default;
|
|
||||||
PathElement(int i, bst_float z, bst_float o, bst_float w) :
|
|
||||||
feature_index(i), zero_fraction(z), one_fraction(o), pweight(w) {}
|
|
||||||
};
|
|
||||||
|
|
||||||
// extend our decision path with a fraction of one and zero extensions
|
|
||||||
void ExtendPath(PathElement *unique_path, unsigned unique_depth,
|
|
||||||
bst_float zero_fraction, bst_float one_fraction,
|
|
||||||
int feature_index) {
|
|
||||||
unique_path[unique_depth].feature_index = feature_index;
|
|
||||||
unique_path[unique_depth].zero_fraction = zero_fraction;
|
|
||||||
unique_path[unique_depth].one_fraction = one_fraction;
|
|
||||||
unique_path[unique_depth].pweight = (unique_depth == 0 ? 1.0f : 0.0f);
|
|
||||||
for (int i = unique_depth - 1; i >= 0; i--) {
|
|
||||||
unique_path[i+1].pweight += one_fraction * unique_path[i].pweight * (i + 1)
|
|
||||||
/ static_cast<bst_float>(unique_depth + 1);
|
|
||||||
unique_path[i].pweight = zero_fraction * unique_path[i].pweight * (unique_depth - i)
|
|
||||||
/ static_cast<bst_float>(unique_depth + 1);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// undo a previous extension of the decision path
|
|
||||||
void UnwindPath(PathElement *unique_path, unsigned unique_depth,
|
|
||||||
unsigned path_index) {
|
|
||||||
const bst_float one_fraction = unique_path[path_index].one_fraction;
|
|
||||||
const bst_float zero_fraction = unique_path[path_index].zero_fraction;
|
|
||||||
bst_float next_one_portion = unique_path[unique_depth].pweight;
|
|
||||||
|
|
||||||
for (int i = unique_depth - 1; i >= 0; --i) {
|
|
||||||
if (one_fraction != 0) {
|
|
||||||
const bst_float tmp = unique_path[i].pweight;
|
|
||||||
unique_path[i].pweight = next_one_portion * (unique_depth + 1)
|
|
||||||
/ static_cast<bst_float>((i + 1) * one_fraction);
|
|
||||||
next_one_portion = tmp - unique_path[i].pweight * zero_fraction * (unique_depth - i)
|
|
||||||
/ static_cast<bst_float>(unique_depth + 1);
|
|
||||||
} else {
|
|
||||||
unique_path[i].pweight = (unique_path[i].pweight * (unique_depth + 1))
|
|
||||||
/ static_cast<bst_float>(zero_fraction * (unique_depth - i));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (auto i = path_index; i < unique_depth; ++i) {
|
|
||||||
unique_path[i].feature_index = unique_path[i+1].feature_index;
|
|
||||||
unique_path[i].zero_fraction = unique_path[i+1].zero_fraction;
|
|
||||||
unique_path[i].one_fraction = unique_path[i+1].one_fraction;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// determine what the total permutation weight would be if
|
|
||||||
// we unwound a previous extension in the decision path
|
|
||||||
bst_float UnwoundPathSum(const PathElement *unique_path, unsigned unique_depth,
|
|
||||||
unsigned path_index) {
|
|
||||||
const bst_float one_fraction = unique_path[path_index].one_fraction;
|
|
||||||
const bst_float zero_fraction = unique_path[path_index].zero_fraction;
|
|
||||||
bst_float next_one_portion = unique_path[unique_depth].pweight;
|
|
||||||
bst_float total = 0;
|
|
||||||
for (int i = unique_depth - 1; i >= 0; --i) {
|
|
||||||
if (one_fraction != 0) {
|
|
||||||
const bst_float tmp = next_one_portion * (unique_depth + 1)
|
|
||||||
/ static_cast<bst_float>((i + 1) * one_fraction);
|
|
||||||
total += tmp;
|
|
||||||
next_one_portion = unique_path[i].pweight - tmp * zero_fraction * ((unique_depth - i)
|
|
||||||
/ static_cast<bst_float>(unique_depth + 1));
|
|
||||||
} else if (zero_fraction != 0) {
|
|
||||||
total += (unique_path[i].pweight / zero_fraction) / ((unique_depth - i)
|
|
||||||
/ static_cast<bst_float>(unique_depth + 1));
|
|
||||||
} else {
|
|
||||||
CHECK_EQ(unique_path[i].pweight, 0)
|
|
||||||
<< "Unique path " << i << " must have zero weight";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return total;
|
|
||||||
}
|
|
||||||
|
|
||||||
// recursive computation of SHAP values for a decision tree
|
|
||||||
void RegTree::TreeShap(const RegTree::FVec &feat, bst_float *phi,
|
|
||||||
bst_node_t node_index, unsigned unique_depth,
|
|
||||||
PathElement *parent_unique_path,
|
|
||||||
bst_float parent_zero_fraction,
|
|
||||||
bst_float parent_one_fraction, int parent_feature_index,
|
|
||||||
int condition, unsigned condition_feature,
|
|
||||||
bst_float condition_fraction) const {
|
|
||||||
const auto node = (*this)[node_index];
|
|
||||||
|
|
||||||
// stop if we have no weight coming down to us
|
|
||||||
if (condition_fraction == 0) return;
|
|
||||||
|
|
||||||
// extend the unique path
|
|
||||||
PathElement *unique_path = parent_unique_path + unique_depth + 1;
|
|
||||||
std::copy(parent_unique_path, parent_unique_path + unique_depth + 1, unique_path);
|
|
||||||
|
|
||||||
if (condition == 0 || condition_feature != static_cast<unsigned>(parent_feature_index)) {
|
|
||||||
ExtendPath(unique_path, unique_depth, parent_zero_fraction,
|
|
||||||
parent_one_fraction, parent_feature_index);
|
|
||||||
}
|
|
||||||
const unsigned split_index = node.SplitIndex();
|
|
||||||
|
|
||||||
// leaf node
|
|
||||||
if (node.IsLeaf()) {
|
|
||||||
for (unsigned i = 1; i <= unique_depth; ++i) {
|
|
||||||
const bst_float w = UnwoundPathSum(unique_path, unique_depth, i);
|
|
||||||
const PathElement &el = unique_path[i];
|
|
||||||
phi[el.feature_index] += w * (el.one_fraction - el.zero_fraction)
|
|
||||||
* node.LeafValue() * condition_fraction;
|
|
||||||
}
|
|
||||||
|
|
||||||
// internal node
|
|
||||||
} else {
|
|
||||||
// find which branch is "hot" (meaning x would follow it)
|
|
||||||
auto const &cats = this->GetCategoriesMatrix();
|
|
||||||
bst_node_t hot_index = predictor::GetNextNode<true, true>(
|
|
||||||
node, node_index, feat.GetFvalue(split_index),
|
|
||||||
feat.IsMissing(split_index), cats);
|
|
||||||
|
|
||||||
const auto cold_index =
|
|
||||||
(hot_index == node.LeftChild() ? node.RightChild() : node.LeftChild());
|
|
||||||
const bst_float w = this->Stat(node_index).sum_hess;
|
|
||||||
const bst_float hot_zero_fraction = this->Stat(hot_index).sum_hess / w;
|
|
||||||
const bst_float cold_zero_fraction = this->Stat(cold_index).sum_hess / w;
|
|
||||||
bst_float incoming_zero_fraction = 1;
|
|
||||||
bst_float incoming_one_fraction = 1;
|
|
||||||
|
|
||||||
// see if we have already split on this feature,
|
|
||||||
// if so we undo that split so we can redo it for this node
|
|
||||||
unsigned path_index = 0;
|
|
||||||
for (; path_index <= unique_depth; ++path_index) {
|
|
||||||
if (static_cast<unsigned>(unique_path[path_index].feature_index) == split_index) break;
|
|
||||||
}
|
|
||||||
if (path_index != unique_depth + 1) {
|
|
||||||
incoming_zero_fraction = unique_path[path_index].zero_fraction;
|
|
||||||
incoming_one_fraction = unique_path[path_index].one_fraction;
|
|
||||||
UnwindPath(unique_path, unique_depth, path_index);
|
|
||||||
unique_depth -= 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
// divide up the condition_fraction among the recursive calls
|
|
||||||
bst_float hot_condition_fraction = condition_fraction;
|
|
||||||
bst_float cold_condition_fraction = condition_fraction;
|
|
||||||
if (condition > 0 && split_index == condition_feature) {
|
|
||||||
cold_condition_fraction = 0;
|
|
||||||
unique_depth -= 1;
|
|
||||||
} else if (condition < 0 && split_index == condition_feature) {
|
|
||||||
hot_condition_fraction *= hot_zero_fraction;
|
|
||||||
cold_condition_fraction *= cold_zero_fraction;
|
|
||||||
unique_depth -= 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
TreeShap(feat, phi, hot_index, unique_depth + 1, unique_path,
|
|
||||||
hot_zero_fraction * incoming_zero_fraction, incoming_one_fraction,
|
|
||||||
split_index, condition, condition_feature, hot_condition_fraction);
|
|
||||||
|
|
||||||
TreeShap(feat, phi, cold_index, unique_depth + 1, unique_path,
|
|
||||||
cold_zero_fraction * incoming_zero_fraction, 0,
|
|
||||||
split_index, condition, condition_feature, cold_condition_fraction);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void RegTree::CalculateContributions(const RegTree::FVec &feat,
|
|
||||||
std::vector<float>* mean_values,
|
|
||||||
bst_float *out_contribs,
|
|
||||||
int condition,
|
|
||||||
unsigned condition_feature) const {
|
|
||||||
// find the expected value of the tree's predictions
|
|
||||||
if (condition == 0) {
|
|
||||||
bst_float node_value = (*mean_values)[0];
|
|
||||||
out_contribs[feat.Size()] += node_value;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Preallocate space for the unique path data
|
|
||||||
const int maxd = this->MaxDepth(0) + 2;
|
|
||||||
std::vector<PathElement> unique_path_data((maxd * (maxd + 1)) / 2);
|
|
||||||
|
|
||||||
TreeShap(feat, out_contribs, 0, 0, unique_path_data.data(),
|
|
||||||
1, 1, -1, condition, condition_feature, 1);
|
|
||||||
}
|
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user