Split up SHAP from RegTree. (#8612)

* Split up SHAP from `RegTree`.

Simplify the tree interface.
This commit is contained in:
Jiaming Yuan
2023-01-04 18:17:47 +08:00
committed by GitHub
parent d308124910
commit beefd28471
7 changed files with 225 additions and 220 deletions

View File

@@ -217,7 +217,7 @@ class RegTree : public Model {
* \param default_left the default direction when feature is unknown
*/
XGBOOST_DEVICE void SetSplit(unsigned split_index, SplitCondT split_cond,
bool default_left = false) {
bool default_left = false) {
if (default_left) split_index |= (1U << 31);
this->sindex_ = split_index;
(this->info_).split_cond = split_cond;
@@ -542,37 +542,6 @@ class RegTree : public Model {
bool has_missing_;
};
/*!
* \brief Calculate the feature contributions (https://arxiv.org/abs/1706.06060) for the tree.
* \param feat dense feature vector; if a feature is missing, its field is set to NaN
* \param mean_values pointer to a vector of floats used by the computation
*        (NOTE(review): presumably per-node mean/expected values of the tree --
*        undocumented in the original comment, confirm against the implementation)
* \param out_contribs output vector to hold the contributions
* \param condition fix one feature to either off (-1), on (1), or not fixed (0, the default)
* \param condition_feature the index of the feature to fix (defaults to 0)
*/
void CalculateContributions(const RegTree::FVec& feat,
std::vector<float>* mean_values,
bst_float* out_contribs, int condition = 0,
unsigned condition_feature = 0) const;
/*!
* \brief Recursive function that computes the feature attributions for a single tree.
* \param feat dense feature vector; if a feature is missing, its field is set to NaN
* \param phi dense output vector of feature attributions
* \param node_index the index of the current node in the tree
* \param unique_depth how many unique features are above the current node in the tree
* \param parent_unique_path a vector of statistics about the current path through the tree
* \param parent_zero_fraction the fraction of the parent path weight coming in as 0 (integrated)
* \param parent_one_fraction the fraction of the parent path weight coming in as 1 (fixed)
* \param parent_feature_index the feature the parent node used to split
* \param condition fix one feature to either off (-1), on (1), or not fixed (0, the default)
* \param condition_feature the index of the feature to fix
* \param condition_fraction the fraction of the current weight that matches the conditioning feature
*/
void TreeShap(const RegTree::FVec& feat, bst_float* phi, bst_node_t node_index,
unsigned unique_depth, PathElement* parent_unique_path,
bst_float parent_zero_fraction, bst_float parent_one_fraction,
int parent_feature_index, int condition,
unsigned condition_feature, bst_float condition_fraction) const;
/*!
* \brief calculate the approximate feature contributions for the given root
* \param feat dense feature vector, if the feature is missing the field is set to NaN