Fix race condition in CPU shap. (#7050)

This commit is contained in:
Jiaming Yuan 2021-06-21 10:03:15 +08:00 committed by GitHub
parent 29f8fd6fee
commit bbfffb444d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 68 additions and 64 deletions

View File

@ -206,22 +206,18 @@ class Predictor {
* \param condition_feature Feature to condition on (i.e. fix) during calculations. * \param condition_feature Feature to condition on (i.e. fix) during calculations.
*/ */
virtual void PredictContribution(DMatrix* dmat, virtual void
HostDeviceVector<bst_float>* out_contribs, PredictContribution(DMatrix *dmat, HostDeviceVector<bst_float> *out_contribs,
const gbm::GBTreeModel& model, const gbm::GBTreeModel &model, unsigned tree_end = 0,
unsigned tree_end = 0, std::vector<bst_float> const *tree_weights = nullptr,
std::vector<bst_float>* tree_weights = nullptr, bool approximate = false, int condition = 0,
bool approximate = false, unsigned condition_feature = 0) const = 0;
int condition = 0,
unsigned condition_feature = 0) const = 0;
virtual void PredictInteractionContributions(DMatrix* dmat,
HostDeviceVector<bst_float>* out_contribs,
const gbm::GBTreeModel& model,
unsigned tree_end = 0,
std::vector<bst_float>* tree_weights = nullptr,
bool approximate = false) const = 0;
virtual void PredictInteractionContributions(
DMatrix *dmat, HostDeviceVector<bst_float> *out_contribs,
const gbm::GBTreeModel &model, unsigned tree_end = 0,
std::vector<bst_float> const *tree_weights = nullptr,
bool approximate = false) const = 0;
/** /**
* \brief Creates a new Predictor*. * \brief Creates a new Predictor*.

View File

@ -550,6 +550,7 @@ class RegTree : public Model {
* \param condition_feature the index of the feature to fix * \param condition_feature the index of the feature to fix
*/ */
void CalculateContributions(const RegTree::FVec& feat, void CalculateContributions(const RegTree::FVec& feat,
std::vector<float>* mean_values,
bst_float* out_contribs, int condition = 0, bst_float* out_contribs, int condition = 0,
unsigned condition_feature = 0) const; unsigned condition_feature = 0) const;
/*! /*!
@ -578,6 +579,7 @@ class RegTree : public Model {
* \param out_contribs output vector to hold the contributions * \param out_contribs output vector to hold the contributions
*/ */
void CalculateContributionsApprox(const RegTree::FVec& feat, void CalculateContributionsApprox(const RegTree::FVec& feat,
std::vector<float>* mean_values,
bst_float* out_contribs) const; bst_float* out_contribs) const;
/*! /*!
* \brief dump the model in the requested format as a text string * \brief dump the model in the requested format as a text string
@ -589,10 +591,6 @@ class RegTree : public Model {
std::string DumpModel(const FeatureMap& fmap, std::string DumpModel(const FeatureMap& fmap,
bool with_stats, bool with_stats,
std::string format) const; std::string format) const;
/*!
* \brief calculate the mean value for each node, required for feature contributions
*/
void FillNodeMeanValues();
/*! /*!
* \brief Get split type for a node. * \brief Get split type for a node.
* \param nidx Index of node. * \param nidx Index of node.
@ -639,7 +637,6 @@ class RegTree : public Model {
std::vector<int> deleted_nodes_; std::vector<int> deleted_nodes_;
// stats of nodes // stats of nodes
std::vector<RTreeNodeStat> stats_; std::vector<RTreeNodeStat> stats_;
std::vector<bst_float> node_mean_values_;
std::vector<FeatureType> split_types_; std::vector<FeatureType> split_types_;
// Categories for each internal node. // Categories for each internal node.
@ -680,7 +677,6 @@ class RegTree : public Model {
nodes_[nid].MarkDelete(); nodes_[nid].MarkDelete();
++param.num_deleted; ++param.num_deleted;
} }
bst_float FillNodeMeanValue(int nid);
}; };
inline void RegTree::FVec::Init(size_t size) { inline void RegTree::FVec::Init(size_t size) {

View File

@ -213,6 +213,32 @@ void PredictBatchByBlockOfRowsKernel(
}); });
} }
/**
 * \brief Recursively compute the hessian-weighted mean prediction for the
 *        subtree rooted at `nidx`, storing the mean of every visited node
 *        into `*mean_values`.
 *
 * Writing into a caller-owned buffer (instead of tree-internal state) keeps
 * this computation thread-local, which is the point of the race-condition fix.
 *
 * \param tree        Tree to walk (read-only).
 * \param nidx        Index of the subtree root to process.
 * \param mean_values Output buffer; must already be sized to the node count.
 * \return The mean value computed for `nidx`.
 */
float FillNodeMeanValues(RegTree const *tree, bst_node_t nidx, std::vector<float> *mean_values) {
  auto const &node = (*tree)[nidx];
  // Leaf: the mean is simply the leaf value.
  if (node.IsLeaf()) {
    bst_float leaf_mean = node.LeafValue();
    (*mean_values)[nidx] = leaf_mean;
    return leaf_mean;
  }
  // Internal node: average the children's means, weighted by their hessian
  // sums, then normalize by this node's hessian sum.
  bst_node_t left = node.LeftChild();
  bst_node_t right = node.RightChild();
  bst_float left_mean = FillNodeMeanValues(tree, left, mean_values);
  bst_float right_mean = FillNodeMeanValues(tree, right, mean_values);
  bst_float mean = left_mean * tree->Stat(left).sum_hess +
                   right_mean * tree->Stat(right).sum_hess;
  mean /= tree->Stat(nidx).sum_hess;
  (*mean_values)[nidx] = mean;
  return mean;
}
/**
 * \brief Populate the per-node mean-value buffer for an entire tree.
 *
 * Idempotent: when `*mean_values` already has one entry per node the buffer is
 * assumed to be filled for this tree and the call is a no-op.
 *
 * \param tree        Tree whose node means are required.
 * \param mean_values Caller-owned (thread-local) output buffer.
 */
void FillNodeMeanValues(RegTree const *tree, std::vector<float> *mean_values) {
  auto n_nodes = static_cast<size_t>(tree->param.num_nodes);
  if (mean_values->size() != n_nodes) {
    mean_values->resize(n_nodes);
    // Start the recursion from the root node (index 0).
    FillNodeMeanValues(tree, 0, mean_values);
  }
}
class CPUPredictor : public Predictor { class CPUPredictor : public Predictor {
protected: protected:
// init thread buffers // init thread buffers
@ -396,9 +422,10 @@ class CPUPredictor : public Predictor {
} }
} }
void PredictContribution(DMatrix* p_fmat, HostDeviceVector<float>* out_contribs, void PredictContribution(DMatrix *p_fmat,
const gbm::GBTreeModel& model, uint32_t ntree_limit, HostDeviceVector<float> *out_contribs,
std::vector<bst_float>* tree_weights, const gbm::GBTreeModel &model, uint32_t ntree_limit,
std::vector<bst_float> const *tree_weights,
bool approximate, int condition, bool approximate, int condition,
unsigned condition_feature) const override { unsigned condition_feature) const override {
const int nthread = omp_get_max_threads(); const int nthread = omp_get_max_threads();
@ -421,8 +448,9 @@ class CPUPredictor : public Predictor {
// allocated one // allocated one
std::fill(contribs.begin(), contribs.end(), 0); std::fill(contribs.begin(), contribs.end(), 0);
// initialize tree node mean values // initialize tree node mean values
std::vector<std::vector<float>> mean_values(ntree_limit);
common::ParallelFor(bst_omp_uint(ntree_limit), [&](bst_omp_uint i) { common::ParallelFor(bst_omp_uint(ntree_limit), [&](bst_omp_uint i) {
model.trees[i]->FillNodeMeanValues(); FillNodeMeanValues(model.trees[i].get(), &(mean_values[i]));
}); });
const std::vector<bst_float>& base_margin = info.base_margin_.HostVector(); const std::vector<bst_float>& base_margin = info.base_margin_.HostVector();
// start collecting the contributions // start collecting the contributions
@ -443,19 +471,23 @@ class CPUPredictor : public Predictor {
feats.Fill(page[i]); feats.Fill(page[i]);
// calculate contributions // calculate contributions
for (unsigned j = 0; j < ntree_limit; ++j) { for (unsigned j = 0; j < ntree_limit; ++j) {
auto *tree_mean_values = &mean_values.at(j);
std::fill(this_tree_contribs.begin(), this_tree_contribs.end(), 0); std::fill(this_tree_contribs.begin(), this_tree_contribs.end(), 0);
if (model.tree_info[j] != gid) { if (model.tree_info[j] != gid) {
continue; continue;
} }
if (!approximate) { if (!approximate) {
model.trees[j]->CalculateContributions(feats, &this_tree_contribs[0], model.trees[j]->CalculateContributions(
condition, condition_feature); feats, tree_mean_values, &this_tree_contribs[0], condition,
condition_feature);
} else { } else {
model.trees[j]->CalculateContributionsApprox(feats, &this_tree_contribs[0]); model.trees[j]->CalculateContributionsApprox(
feats, tree_mean_values, &this_tree_contribs[0]);
} }
for (size_t ci = 0 ; ci < ncolumns ; ++ci) { for (size_t ci = 0; ci < ncolumns; ++ci) {
p_contribs[ci] += this_tree_contribs[ci] * p_contribs[ci] +=
(tree_weights == nullptr ? 1 : (*tree_weights)[j]); this_tree_contribs[ci] *
(tree_weights == nullptr ? 1 : (*tree_weights)[j]);
} }
} }
feats.Drop(page[i]); feats.Drop(page[i]);
@ -470,10 +502,11 @@ class CPUPredictor : public Predictor {
} }
} }
void PredictInteractionContributions(DMatrix* p_fmat, HostDeviceVector<bst_float>* out_contribs, void PredictInteractionContributions(
const gbm::GBTreeModel& model, unsigned ntree_limit, DMatrix *p_fmat, HostDeviceVector<bst_float> *out_contribs,
std::vector<bst_float>* tree_weights, const gbm::GBTreeModel &model, unsigned ntree_limit,
bool approximate) const override { std::vector<bst_float> const *tree_weights,
bool approximate) const override {
const MetaInfo& info = p_fmat->Info(); const MetaInfo& info = p_fmat->Info();
const int ngroup = model.learner_model_param->num_output_group; const int ngroup = model.learner_model_param->num_output_group;
size_t const ncolumns = model.learner_model_param->num_feature; size_t const ncolumns = model.learner_model_param->num_feature;

View File

@ -696,7 +696,7 @@ class GPUPredictor : public xgboost::Predictor {
void PredictContribution(DMatrix* p_fmat, void PredictContribution(DMatrix* p_fmat,
HostDeviceVector<bst_float>* out_contribs, HostDeviceVector<bst_float>* out_contribs,
const gbm::GBTreeModel& model, unsigned tree_end, const gbm::GBTreeModel& model, unsigned tree_end,
std::vector<bst_float>*, std::vector<bst_float> const*,
bool approximate, int, bool approximate, int,
unsigned) const override { unsigned) const override {
if (approximate) { if (approximate) {
@ -746,7 +746,7 @@ class GPUPredictor : public xgboost::Predictor {
HostDeviceVector<bst_float>* out_contribs, HostDeviceVector<bst_float>* out_contribs,
const gbm::GBTreeModel& model, const gbm::GBTreeModel& model,
unsigned tree_end, unsigned tree_end,
std::vector<bst_float>*, std::vector<bst_float> const*,
bool approximate) const override { bool approximate) const override {
if (approximate) { if (approximate) {
LOG(FATAL) << "[Internal error]: " << __func__ LOG(FATAL) << "[Internal error]: " << __func__

View File

@ -1128,36 +1128,14 @@ void RegTree::SaveModel(Json* p_out) const {
out["default_left"] = std::move(default_left); out["default_left"] = std::move(default_left);
} }
void RegTree::FillNodeMeanValues() {
size_t num_nodes = this->param.num_nodes;
if (this->node_mean_values_.size() == num_nodes) {
return;
}
this->node_mean_values_.resize(num_nodes);
this->FillNodeMeanValue(0);
}
bst_float RegTree::FillNodeMeanValue(int nid) {
bst_float result;
auto& node = (*this)[nid];
if (node.IsLeaf()) {
result = node.LeafValue();
} else {
result = this->FillNodeMeanValue(node.LeftChild()) * this->Stat(node.LeftChild()).sum_hess;
result += this->FillNodeMeanValue(node.RightChild()) * this->Stat(node.RightChild()).sum_hess;
result /= this->Stat(nid).sum_hess;
}
this->node_mean_values_[nid] = result;
return result;
}
void RegTree::CalculateContributionsApprox(const RegTree::FVec &feat, void RegTree::CalculateContributionsApprox(const RegTree::FVec &feat,
std::vector<float>* mean_values,
bst_float *out_contribs) const { bst_float *out_contribs) const {
CHECK_GT(this->node_mean_values_.size(), 0U); CHECK_GT(mean_values->size(), 0U);
// this follows the idea of http://blog.datadive.net/interpreting-random-forests/ // this follows the idea of http://blog.datadive.net/interpreting-random-forests/
unsigned split_index = 0; unsigned split_index = 0;
// update bias value // update bias value
bst_float node_value = this->node_mean_values_[0]; bst_float node_value = (*mean_values)[0];
out_contribs[feat.Size()] += node_value; out_contribs[feat.Size()] += node_value;
if ((*this)[0].IsLeaf()) { if ((*this)[0].IsLeaf()) {
// nothing to do anymore // nothing to do anymore
@ -1172,7 +1150,7 @@ void RegTree::CalculateContributionsApprox(const RegTree::FVec &feat,
nid = predictor::GetNextNode<true, true>((*this)[nid], nid, nid = predictor::GetNextNode<true, true>((*this)[nid], nid,
feat.GetFvalue(split_index), feat.GetFvalue(split_index),
feat.IsMissing(split_index), cats); feat.IsMissing(split_index), cats);
bst_float new_value = this->node_mean_values_[nid]; bst_float new_value = (*mean_values)[nid];
// update feature weight // update feature weight
out_contribs[split_index] += new_value - node_value; out_contribs[split_index] += new_value - node_value;
node_value = new_value; node_value = new_value;
@ -1352,12 +1330,13 @@ void RegTree::TreeShap(const RegTree::FVec &feat, bst_float *phi,
} }
void RegTree::CalculateContributions(const RegTree::FVec &feat, void RegTree::CalculateContributions(const RegTree::FVec &feat,
std::vector<float>* mean_values,
bst_float *out_contribs, bst_float *out_contribs,
int condition, int condition,
unsigned condition_feature) const { unsigned condition_feature) const {
// find the expected value of the tree's predictions // find the expected value of the tree's predictions
if (condition == 0) { if (condition == 0) {
bst_float node_value = this->node_mean_values_[0]; bst_float node_value = (*mean_values)[0];
out_contribs[feat.Size()] += node_value; out_contribs[feat.Size()] += node_value;
} }