Fix race condition in CPU shap. (#7050)
This commit is contained in:
@@ -1128,36 +1128,14 @@ void RegTree::SaveModel(Json* p_out) const {
|
||||
out["default_left"] = std::move(default_left);
|
||||
}
|
||||
|
||||
void RegTree::FillNodeMeanValues() {
|
||||
size_t num_nodes = this->param.num_nodes;
|
||||
if (this->node_mean_values_.size() == num_nodes) {
|
||||
return;
|
||||
}
|
||||
this->node_mean_values_.resize(num_nodes);
|
||||
this->FillNodeMeanValue(0);
|
||||
}
|
||||
|
||||
bst_float RegTree::FillNodeMeanValue(int nid) {
|
||||
bst_float result;
|
||||
auto& node = (*this)[nid];
|
||||
if (node.IsLeaf()) {
|
||||
result = node.LeafValue();
|
||||
} else {
|
||||
result = this->FillNodeMeanValue(node.LeftChild()) * this->Stat(node.LeftChild()).sum_hess;
|
||||
result += this->FillNodeMeanValue(node.RightChild()) * this->Stat(node.RightChild()).sum_hess;
|
||||
result /= this->Stat(nid).sum_hess;
|
||||
}
|
||||
this->node_mean_values_[nid] = result;
|
||||
return result;
|
||||
}
|
||||
|
||||
void RegTree::CalculateContributionsApprox(const RegTree::FVec &feat,
|
||||
std::vector<float>* mean_values,
|
||||
bst_float *out_contribs) const {
|
||||
CHECK_GT(this->node_mean_values_.size(), 0U);
|
||||
CHECK_GT(mean_values->size(), 0U);
|
||||
// this follows the idea of http://blog.datadive.net/interpreting-random-forests/
|
||||
unsigned split_index = 0;
|
||||
// update bias value
|
||||
bst_float node_value = this->node_mean_values_[0];
|
||||
bst_float node_value = (*mean_values)[0];
|
||||
out_contribs[feat.Size()] += node_value;
|
||||
if ((*this)[0].IsLeaf()) {
|
||||
// nothing to do anymore
|
||||
@@ -1172,7 +1150,7 @@ void RegTree::CalculateContributionsApprox(const RegTree::FVec &feat,
|
||||
nid = predictor::GetNextNode<true, true>((*this)[nid], nid,
|
||||
feat.GetFvalue(split_index),
|
||||
feat.IsMissing(split_index), cats);
|
||||
bst_float new_value = this->node_mean_values_[nid];
|
||||
bst_float new_value = (*mean_values)[nid];
|
||||
// update feature weight
|
||||
out_contribs[split_index] += new_value - node_value;
|
||||
node_value = new_value;
|
||||
@@ -1352,12 +1330,13 @@ void RegTree::TreeShap(const RegTree::FVec &feat, bst_float *phi,
|
||||
}
|
||||
|
||||
void RegTree::CalculateContributions(const RegTree::FVec &feat,
|
||||
std::vector<float>* mean_values,
|
||||
bst_float *out_contribs,
|
||||
int condition,
|
||||
unsigned condition_feature) const {
|
||||
// find the expected value of the tree's predictions
|
||||
if (condition == 0) {
|
||||
bst_float node_value = this->node_mean_values_[0];
|
||||
bst_float node_value = (*mean_values)[0];
|
||||
out_contribs[feat.Size()] += node_value;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user