Implement categorical data support for SHAP. (#7053)
* Add CPU implementation. * Update GPUTreeSHAP. * Add GPU implementation by defining custom split condition.
This commit is contained in:
@@ -1245,7 +1245,7 @@ bst_float UnwoundPathSum(const PathElement *unique_path, unsigned unique_depth,
|
||||
|
||||
// recursive computation of SHAP values for a decision tree
|
||||
void RegTree::TreeShap(const RegTree::FVec &feat, bst_float *phi,
|
||||
unsigned node_index, unsigned unique_depth,
|
||||
bst_node_t node_index, unsigned unique_depth,
|
||||
PathElement *parent_unique_path,
|
||||
bst_float parent_zero_fraction,
|
||||
bst_float parent_one_fraction, int parent_feature_index,
|
||||
@@ -1278,16 +1278,13 @@ void RegTree::TreeShap(const RegTree::FVec &feat, bst_float *phi,
|
||||
// internal node
|
||||
} else {
|
||||
// find which branch is "hot" (meaning x would follow it)
|
||||
unsigned hot_index = 0;
|
||||
if (feat.IsMissing(split_index)) {
|
||||
hot_index = node.DefaultChild();
|
||||
} else if (feat.GetFvalue(split_index) < node.SplitCond()) {
|
||||
hot_index = node.LeftChild();
|
||||
} else {
|
||||
hot_index = node.RightChild();
|
||||
}
|
||||
const unsigned cold_index = (static_cast<int>(hot_index) == node.LeftChild() ?
|
||||
node.RightChild() : node.LeftChild());
|
||||
auto const &cats = this->GetCategoriesMatrix();
|
||||
bst_node_t hot_index = predictor::GetNextNode<true, true>(
|
||||
node, node_index, feat.GetFvalue(split_index),
|
||||
feat.IsMissing(split_index), cats);
|
||||
|
||||
const auto cold_index =
|
||||
(hot_index == node.LeftChild() ? node.RightChild() : node.LeftChild());
|
||||
const bst_float w = this->Stat(node_index).sum_hess;
|
||||
const bst_float hot_zero_fraction = this->Stat(hot_index).sum_hess / w;
|
||||
const bst_float cold_zero_fraction = this->Stat(cold_index).sum_hess / w;
|
||||
|
||||
Reference in New Issue
Block a user