Fix race condition in CPU shap. (#7050)

Jiaming Yuan 2021-06-21 10:03:15 +08:00 committed by GitHub
parent 29f8fd6fee
commit bbfffb444d
5 changed files with 68 additions and 64 deletions
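
Context for the diffs below: the per-node mean values used by SHAP contribution prediction were previously cached inside each tree (RegTree::node_mean_values_, filled lazily by RegTree::FillNodeMeanValues()), so concurrent PredictContribution calls against the same model could write that shared vector at the same time. A minimal stand-in for that pattern, in plain C++ rather than xgboost code (Tree, FillMeans, and mean_cache_ are made-up names):

// Simplified stand-in for the old design: a lazily filled member cache that
// two concurrent prediction calls both write (compile with -pthread).
#include <cstddef>
#include <thread>
#include <vector>

struct Tree {
  std::vector<float> leaf_values{1.f, 2.f, 3.f, 4.f};
  std::vector<float> mean_cache_;  // shared mutable state, like node_mean_values_

  void FillMeans() {  // analogous to the old RegTree::FillNodeMeanValues()
    if (mean_cache_.size() == leaf_values.size()) {
      return;
    }
    mean_cache_.resize(leaf_values.size());  // unsynchronized resize and writes
    for (std::size_t i = 0; i < leaf_values.size(); ++i) {
      mean_cache_[i] = leaf_values[i];
    }
  }
};

int main() {
  Tree tree;
  // Two concurrent contribution predictions against the same model both
  // mutate tree.mean_cache_ -- the kind of race this commit removes.
  std::thread t1([&] { tree.FillMeans(); });
  std::thread t2([&] { tree.FillMeans(); });
  t1.join();
  t2.join();
  return 0;
}

Run under a thread sanitizer, the two unsynchronized FillMeans calls are typically reported as a data race on the cache vector; the diffs below remove the equivalent shared state from RegTree.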

View File

@@ -206,22 +206,18 @@ class Predictor {
* \param condition_feature Feature to condition on (i.e. fix) during calculations.
*/
virtual void PredictContribution(DMatrix* dmat,
HostDeviceVector<bst_float>* out_contribs,
const gbm::GBTreeModel& model,
unsigned tree_end = 0,
std::vector<bst_float>* tree_weights = nullptr,
bool approximate = false,
int condition = 0,
unsigned condition_feature = 0) const = 0;
virtual void PredictInteractionContributions(DMatrix* dmat,
HostDeviceVector<bst_float>* out_contribs,
const gbm::GBTreeModel& model,
unsigned tree_end = 0,
std::vector<bst_float>* tree_weights = nullptr,
bool approximate = false) const = 0;
virtual void
PredictContribution(DMatrix *dmat, HostDeviceVector<bst_float> *out_contribs,
const gbm::GBTreeModel &model, unsigned tree_end = 0,
std::vector<bst_float> const *tree_weights = nullptr,
bool approximate = false, int condition = 0,
unsigned condition_feature = 0) const = 0;
virtual void PredictInteractionContributions(
DMatrix *dmat, HostDeviceVector<bst_float> *out_contribs,
const gbm::GBTreeModel &model, unsigned tree_end = 0,
std::vector<bst_float> const *tree_weights = nullptr,
bool approximate = false) const = 0;
/**
* \brief Creates a new Predictor*.

View File

@@ -550,6 +550,7 @@ class RegTree : public Model {
* \param condition_feature the index of the feature to fix
*/
void CalculateContributions(const RegTree::FVec& feat,
std::vector<float>* mean_values,
bst_float* out_contribs, int condition = 0,
unsigned condition_feature = 0) const;
/*!
@@ -578,6 +579,7 @@ class RegTree : public Model {
* \param out_contribs output vector to hold the contributions
*/
void CalculateContributionsApprox(const RegTree::FVec& feat,
std::vector<float>* mean_values,
bst_float* out_contribs) const;
/*!
* \brief dump the model in the requested format as a text string
@@ -589,10 +591,6 @@ class RegTree : public Model {
std::string DumpModel(const FeatureMap& fmap,
bool with_stats,
std::string format) const;
/*!
* \brief calculate the mean value for each node, required for feature contributions
*/
void FillNodeMeanValues();
/*!
* \brief Get split type for a node.
* \param nidx Index of node.
@@ -639,7 +637,6 @@ class RegTree : public Model {
std::vector<int> deleted_nodes_;
// stats of nodes
std::vector<RTreeNodeStat> stats_;
std::vector<bst_float> node_mean_values_;
std::vector<FeatureType> split_types_;
// Categories for each internal node.
@@ -680,7 +677,6 @@ class RegTree : public Model {
nodes_[nid].MarkDelete();
++param.num_deleted;
}
bst_float FillNodeMeanValue(int nid);
};
inline void RegTree::FVec::Init(size_t size) {

View File

@@ -213,6 +213,32 @@ void PredictBatchByBlockOfRowsKernel(
});
}
float FillNodeMeanValues(RegTree const *tree, bst_node_t nidx, std::vector<float> *mean_values) {
bst_float result;
auto &node = (*tree)[nidx];
auto &node_mean_values = *mean_values;
if (node.IsLeaf()) {
result = node.LeafValue();
} else {
result = FillNodeMeanValues(tree, node.LeftChild(), mean_values) *
tree->Stat(node.LeftChild()).sum_hess;
result += FillNodeMeanValues(tree, node.RightChild(), mean_values) *
tree->Stat(node.RightChild()).sum_hess;
result /= tree->Stat(nidx).sum_hess;
}
node_mean_values[nidx] = result;
return result;
}
void FillNodeMeanValues(RegTree const* tree, std::vector<float>* mean_values) {
size_t num_nodes = tree->param.num_nodes;
if (mean_values->size() == num_nodes) {
return;
}
mean_values->resize(num_nodes);
FillNodeMeanValues(tree, 0, mean_values);
}
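
The two helpers above replace that member cache: the fill routine now only reads the tree and writes into a buffer the caller supplies per tree. A simplified, self-contained sketch of the same caller-owned-buffer pattern (plain C++, made-up Tree/FillMeans names, not xgboost code):

// Simplified stand-in for the new design: the fill routine reads a const
// tree and writes only into a caller-owned buffer (compile with -pthread).
#include <cstddef>
#include <thread>
#include <vector>

struct Tree {
  std::vector<float> leaf_values{1.f, 2.f, 3.f, 4.f};
};

// Mirrors the shape of the free FillNodeMeanValues above: const input tree,
// output buffer owned by the caller.
void FillMeans(Tree const &tree, std::vector<float> *out) {
  if (out->size() == tree.leaf_values.size()) {
    return;
  }
  out->resize(tree.leaf_values.size());
  for (std::size_t i = 0; i < tree.leaf_values.size(); ++i) {
    (*out)[i] = tree.leaf_values[i];
  }
}

int main() {
  Tree tree;  // never mutated during "prediction"
  // Each concurrent call brings its own buffer, so no shared state is written.
  std::thread t1([&] { std::vector<float> means; FillMeans(tree, &means); });
  std::thread t2([&] { std::vector<float> means; FillMeans(tree, &means); });
  t1.join();
  t2.join();
  return 0;
}

In the PredictContribution diff that follows, each call allocates its own mean_values (one vector per tree), so concurrent contribution predictions no longer mutate the model.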
class CPUPredictor : public Predictor {
protected:
// init thread buffers
@@ -396,9 +422,10 @@ class CPUPredictor : public Predictor {
}
}
void PredictContribution(DMatrix* p_fmat, HostDeviceVector<float>* out_contribs,
const gbm::GBTreeModel& model, uint32_t ntree_limit,
std::vector<bst_float>* tree_weights,
void PredictContribution(DMatrix *p_fmat,
HostDeviceVector<float> *out_contribs,
const gbm::GBTreeModel &model, uint32_t ntree_limit,
std::vector<bst_float> const *tree_weights,
bool approximate, int condition,
unsigned condition_feature) const override {
const int nthread = omp_get_max_threads();
@@ -421,8 +448,9 @@ class CPUPredictor : public Predictor {
// allocated one
std::fill(contribs.begin(), contribs.end(), 0);
// initialize tree node mean values
std::vector<std::vector<float>> mean_values(ntree_limit);
common::ParallelFor(bst_omp_uint(ntree_limit), [&](bst_omp_uint i) {
model.trees[i]->FillNodeMeanValues();
FillNodeMeanValues(model.trees[i].get(), &(mean_values[i]));
});
const std::vector<bst_float>& base_margin = info.base_margin_.HostVector();
// start collecting the contributions
@@ -443,19 +471,23 @@ class CPUPredictor : public Predictor {
feats.Fill(page[i]);
// calculate contributions
for (unsigned j = 0; j < ntree_limit; ++j) {
auto *tree_mean_values = &mean_values.at(j);
std::fill(this_tree_contribs.begin(), this_tree_contribs.end(), 0);
if (model.tree_info[j] != gid) {
continue;
}
if (!approximate) {
model.trees[j]->CalculateContributions(feats, &this_tree_contribs[0],
condition, condition_feature);
model.trees[j]->CalculateContributions(
feats, tree_mean_values, &this_tree_contribs[0], condition,
condition_feature);
} else {
model.trees[j]->CalculateContributionsApprox(feats, &this_tree_contribs[0]);
model.trees[j]->CalculateContributionsApprox(
feats, tree_mean_values, &this_tree_contribs[0]);
}
for (size_t ci = 0 ; ci < ncolumns ; ++ci) {
p_contribs[ci] += this_tree_contribs[ci] *
(tree_weights == nullptr ? 1 : (*tree_weights)[j]);
for (size_t ci = 0; ci < ncolumns; ++ci) {
p_contribs[ci] +=
this_tree_contribs[ci] *
(tree_weights == nullptr ? 1 : (*tree_weights)[j]);
}
}
feats.Drop(page[i]);
@@ -470,10 +502,11 @@ class CPUPredictor : public Predictor {
}
}
void PredictInteractionContributions(DMatrix* p_fmat, HostDeviceVector<bst_float>* out_contribs,
const gbm::GBTreeModel& model, unsigned ntree_limit,
std::vector<bst_float>* tree_weights,
bool approximate) const override {
void PredictInteractionContributions(
DMatrix *p_fmat, HostDeviceVector<bst_float> *out_contribs,
const gbm::GBTreeModel &model, unsigned ntree_limit,
std::vector<bst_float> const *tree_weights,
bool approximate) const override {
const MetaInfo& info = p_fmat->Info();
const int ngroup = model.learner_model_param->num_output_group;
size_t const ncolumns = model.learner_model_param->num_feature;

View File

@@ -696,7 +696,7 @@ class GPUPredictor : public xgboost::Predictor {
void PredictContribution(DMatrix* p_fmat,
HostDeviceVector<bst_float>* out_contribs,
const gbm::GBTreeModel& model, unsigned tree_end,
std::vector<bst_float>*,
std::vector<bst_float> const*,
bool approximate, int,
unsigned) const override {
if (approximate) {
@@ -746,7 +746,7 @@ class GPUPredictor : public xgboost::Predictor {
HostDeviceVector<bst_float>* out_contribs,
const gbm::GBTreeModel& model,
unsigned tree_end,
std::vector<bst_float>*,
std::vector<bst_float> const*,
bool approximate) const override {
if (approximate) {
LOG(FATAL) << "[Internal error]: " << __func__

View File

@@ -1128,36 +1128,14 @@ void RegTree::SaveModel(Json* p_out) const {
out["default_left"] = std::move(default_left);
}
void RegTree::FillNodeMeanValues() {
size_t num_nodes = this->param.num_nodes;
if (this->node_mean_values_.size() == num_nodes) {
return;
}
this->node_mean_values_.resize(num_nodes);
this->FillNodeMeanValue(0);
}
bst_float RegTree::FillNodeMeanValue(int nid) {
bst_float result;
auto& node = (*this)[nid];
if (node.IsLeaf()) {
result = node.LeafValue();
} else {
result = this->FillNodeMeanValue(node.LeftChild()) * this->Stat(node.LeftChild()).sum_hess;
result += this->FillNodeMeanValue(node.RightChild()) * this->Stat(node.RightChild()).sum_hess;
result /= this->Stat(nid).sum_hess;
}
this->node_mean_values_[nid] = result;
return result;
}
void RegTree::CalculateContributionsApprox(const RegTree::FVec &feat,
std::vector<float>* mean_values,
bst_float *out_contribs) const {
CHECK_GT(this->node_mean_values_.size(), 0U);
CHECK_GT(mean_values->size(), 0U);
// this follows the idea of http://blog.datadive.net/interpreting-random-forests/
unsigned split_index = 0;
// update bias value
bst_float node_value = this->node_mean_values_[0];
bst_float node_value = (*mean_values)[0];
out_contribs[feat.Size()] += node_value;
if ((*this)[0].IsLeaf()) {
// nothing to do anymore
@@ -1172,7 +1150,7 @@ void RegTree::CalculateContributionsApprox(const RegTree::FVec &feat,
nid = predictor::GetNextNode<true, true>((*this)[nid], nid,
feat.GetFvalue(split_index),
feat.IsMissing(split_index), cats);
bst_float new_value = this->node_mean_values_[nid];
bst_float new_value = (*mean_values)[nid];
// update feature weight
out_contribs[split_index] += new_value - node_value;
node_value = new_value;
@@ -1352,12 +1330,13 @@ void RegTree::TreeShap(const RegTree::FVec &feat, bst_float *phi,
}
void RegTree::CalculateContributions(const RegTree::FVec &feat,
std::vector<float>* mean_values,
bst_float *out_contribs,
int condition,
unsigned condition_feature) const {
// find the expected value of the tree's predictions
if (condition == 0) {
bst_float node_value = this->node_mean_values_[0];
bst_float node_value = (*mean_values)[0];
out_contribs[feat.Size()] += node_value;
}
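
For reference, the approximate contribution path shown earlier credits each split feature with the change in node mean value along the decision route, following the interpreting-random-forests idea cited in the code. A toy, self-contained sketch of that computation (hypothetical tree, thresholds, and mean values; not xgboost's data structures):

#include <cstddef>
#include <cstdio>
#include <vector>

struct Node {
  int split_feature;  // -1 marks a leaf
  int left, right;    // child indices (unused for leaves)
  float threshold;    // split threshold
};

int main() {
  // Toy tree: node 0 splits on feature 0, node 1 splits on feature 1.
  std::vector<Node> nodes = {{0, 1, 2, 0.5f},
                             {1, 3, 4, 2.0f},
                             {-1, 0, 0, 0.f},
                             {-1, 0, 0, 0.f},
                             {-1, 0, 0, 0.f}};
  // Hypothetical per-node mean predictions, as FillNodeMeanValues would produce.
  std::vector<float> mean = {0.3f, 0.1f, 0.8f, -0.2f, 0.6f};
  std::vector<float> x = {0.2f, 3.0f};             // one input row
  std::vector<float> contribs(x.size() + 1, 0.f);  // last slot is the bias

  int nid = 0;
  contribs.back() += mean[0];  // bias: mean value of the root
  while (nodes[nid].split_feature >= 0) {
    int f = nodes[nid].split_feature;
    int next = x[f] < nodes[nid].threshold ? nodes[nid].left : nodes[nid].right;
    contribs[f] += mean[next] - mean[nid];  // credit the split feature
    nid = next;
  }
  for (std::size_t i = 0; i < contribs.size(); ++i) {
    std::printf("contrib[%zu] = %f\n", i, contribs[i]);
  }
  return 0;
}

The per-feature contributions plus the bias slot sum to the mean value of the leaf that is reached, which is what the telescoping update in CalculateContributionsApprox preserves.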