Implement categorical data support for SHAP. (#7053)

* Add CPU implementation.
* Update GPUTreeSHAP.
* Add GPU implementation by defining custom split condition.
This commit is contained in:
Jiaming Yuan
2021-06-25 19:02:46 +08:00
committed by GitHub
parent 663136aa08
commit 8fa32fdda2
12 changed files with 287 additions and 50 deletions

View File

@@ -1245,7 +1245,7 @@ bst_float UnwoundPathSum(const PathElement *unique_path, unsigned unique_depth,
// recursive computation of SHAP values for a decision tree
void RegTree::TreeShap(const RegTree::FVec &feat, bst_float *phi,
unsigned node_index, unsigned unique_depth,
bst_node_t node_index, unsigned unique_depth,
PathElement *parent_unique_path,
bst_float parent_zero_fraction,
bst_float parent_one_fraction, int parent_feature_index,
@@ -1278,16 +1278,13 @@ void RegTree::TreeShap(const RegTree::FVec &feat, bst_float *phi,
// internal node
} else {
// find which branch is "hot" (meaning x would follow it)
unsigned hot_index = 0;
if (feat.IsMissing(split_index)) {
hot_index = node.DefaultChild();
} else if (feat.GetFvalue(split_index) < node.SplitCond()) {
hot_index = node.LeftChild();
} else {
hot_index = node.RightChild();
}
const unsigned cold_index = (static_cast<int>(hot_index) == node.LeftChild() ?
node.RightChild() : node.LeftChild());
auto const &cats = this->GetCategoriesMatrix();
bst_node_t hot_index = predictor::GetNextNode<true, true>(
node, node_index, feat.GetFvalue(split_index),
feat.IsMissing(split_index), cats);
const auto cold_index =
(hot_index == node.LeftChild() ? node.RightChild() : node.LeftChild());
const bst_float w = this->Stat(node_index).sum_hess;
const bst_float hot_zero_fraction = this->Stat(hot_index).sum_hess / w;
const bst_float cold_zero_fraction = this->Stat(cold_index).sum_hess / w;