Fix race condition in CPU shap. (#7050)
This commit is contained in:
parent
29f8fd6fee
commit
bbfffb444d
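
Background: PredictContribution on the CPU used to call RegTree::FillNodeMeanValues(), which lazily filled a node_mean_values_ member on each tree. That cache is shared, mutable state, so SHAP predictions issued concurrently against the same booster could race while writing it. This commit removes the member (along with RegTree::FillNodeMeanValues and RegTree::FillNodeMeanValue) and instead has the CPU predictor allocate a per-call std::vector<std::vector<float>> of node means, which it passes down to CalculateContributions and CalculateContributionsApprox; the tree_weights argument also becomes a pointer-to-const to make the read-only intent explicit, and the GPU predictor changes only its signatures to match.

The sketch below is not XGBoost code; Tree, Node, and FillMeanValues are made-up names that illustrate the caller-owned-buffer pattern the fix adopts, with each thread writing only into its own scratch vector:

// Sketch only -- not XGBoost code.  Tree, Node, and FillMeanValues are
// illustrative names.  The point: the mean-value cache lives in a buffer owned
// by the caller, not in a member of the tree, so concurrent predictions never
// write to shared state.
#include <iostream>
#include <thread>
#include <vector>

struct Node {
  int left;        // index of left child, -1 for a leaf
  int right;       // index of right child, -1 for a leaf
  float value;     // leaf value
  float sum_hess;  // "cover" used to weight the children
};

struct Tree {
  std::vector<Node> nodes;
  bool IsLeaf(int nidx) const { return nodes[nidx].left == -1; }
};

// Recursively compute the cover-weighted mean leaf value below nidx and store
// it in the caller-owned buffer (analogue of the new free FillNodeMeanValues).
float FillMeanValues(const Tree& tree, int nidx, std::vector<float>* out) {
  const Node& n = tree.nodes[nidx];
  float result;
  if (tree.IsLeaf(nidx)) {
    result = n.value;
  } else {
    result = FillMeanValues(tree, n.left, out) * tree.nodes[n.left].sum_hess +
             FillMeanValues(tree, n.right, out) * tree.nodes[n.right].sum_hess;
    result /= n.sum_hess;
  }
  (*out)[nidx] = result;
  return result;
}

int main() {
  // A single split: root (cover 2) with two leaves of value 1 and 3.
  Tree tree;
  tree.nodes = {{1, 2, 0.0f, 2.0f}, {-1, -1, 1.0f, 1.0f}, {-1, -1, 3.0f, 1.0f}};

  // Each "prediction" owns its scratch buffer, mirroring the per-call
  // std::vector<std::vector<float>> the patched CPUPredictor allocates.
  auto predict = [&tree]() {
    std::vector<float> mean_values(tree.nodes.size(), 0.0f);
    FillMeanValues(tree, 0, &mean_values);
    return mean_values[0];
  };
  std::thread t1([&] { predict(); });
  std::thread t2([&] { predict(); });
  t1.join();
  t2.join();

  std::cout << "expected value of the tree: " << predict() << "\n";  // (1 + 3) / 2 = 2
  return 0;
}

Because every caller owns its buffer, no locking is needed and the tree object stays read-only during prediction.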
@@ -206,22 +206,18 @@ class Predictor {
    * \param condition_feature Feature to condition on (i.e. fix) during calculations.
    */
 
-  virtual void PredictContribution(DMatrix* dmat,
-                                   HostDeviceVector<bst_float>* out_contribs,
-                                   const gbm::GBTreeModel& model,
-                                   unsigned tree_end = 0,
-                                   std::vector<bst_float>* tree_weights = nullptr,
-                                   bool approximate = false,
-                                   int condition = 0,
-                                   unsigned condition_feature = 0) const = 0;
-
-  virtual void PredictInteractionContributions(DMatrix* dmat,
-                                               HostDeviceVector<bst_float>* out_contribs,
-                                               const gbm::GBTreeModel& model,
-                                               unsigned tree_end = 0,
-                                               std::vector<bst_float>* tree_weights = nullptr,
-                                               bool approximate = false) const = 0;
+  virtual void
+  PredictContribution(DMatrix *dmat, HostDeviceVector<bst_float> *out_contribs,
+                      const gbm::GBTreeModel &model, unsigned tree_end = 0,
+                      std::vector<bst_float> const *tree_weights = nullptr,
+                      bool approximate = false, int condition = 0,
+                      unsigned condition_feature = 0) const = 0;
+
+  virtual void PredictInteractionContributions(
+      DMatrix *dmat, HostDeviceVector<bst_float> *out_contribs,
+      const gbm::GBTreeModel &model, unsigned tree_end = 0,
+      std::vector<bst_float> const *tree_weights = nullptr,
+      bool approximate = false) const = 0;
 
   /**
    * \brief Creates a new Predictor*.
@@ -550,6 +550,7 @@ class RegTree : public Model {
    * \param condition_feature the index of the feature to fix
    */
   void CalculateContributions(const RegTree::FVec& feat,
+                              std::vector<float>* mean_values,
                               bst_float* out_contribs, int condition = 0,
                               unsigned condition_feature = 0) const;
   /*!
@@ -578,6 +579,7 @@ class RegTree : public Model {
    * \param out_contribs output vector to hold the contributions
    */
   void CalculateContributionsApprox(const RegTree::FVec& feat,
+                                    std::vector<float>* mean_values,
                                     bst_float* out_contribs) const;
   /*!
    * \brief dump the model in the requested format as a text string
@@ -589,10 +591,6 @@ class RegTree : public Model {
   std::string DumpModel(const FeatureMap& fmap,
                         bool with_stats,
                         std::string format) const;
-  /*!
-   * \brief calculate the mean value for each node, required for feature contributions
-   */
-  void FillNodeMeanValues();
   /*!
    * \brief Get split type for a node.
    * \param nidx Index of node.
@@ -639,7 +637,6 @@ class RegTree : public Model {
   std::vector<int> deleted_nodes_;
   // stats of nodes
   std::vector<RTreeNodeStat> stats_;
-  std::vector<bst_float> node_mean_values_;
   std::vector<FeatureType> split_types_;
 
   // Categories for each internal node.
@@ -680,7 +677,6 @@ class RegTree : public Model {
     nodes_[nid].MarkDelete();
     ++param.num_deleted;
   }
-  bst_float FillNodeMeanValue(int nid);
 };
 
 inline void RegTree::FVec::Init(size_t size) {
@@ -213,6 +213,32 @@ void PredictBatchByBlockOfRowsKernel(
   });
 }
 
+float FillNodeMeanValues(RegTree const *tree, bst_node_t nidx, std::vector<float> *mean_values) {
+  bst_float result;
+  auto &node = (*tree)[nidx];
+  auto &node_mean_values = *mean_values;
+  if (node.IsLeaf()) {
+    result = node.LeafValue();
+  } else {
+    result = FillNodeMeanValues(tree, node.LeftChild(), mean_values) *
+             tree->Stat(node.LeftChild()).sum_hess;
+    result += FillNodeMeanValues(tree, node.RightChild(), mean_values) *
+              tree->Stat(node.RightChild()).sum_hess;
+    result /= tree->Stat(nidx).sum_hess;
+  }
+  node_mean_values[nidx] = result;
+  return result;
+}
+
+void FillNodeMeanValues(RegTree const* tree, std::vector<float>* mean_values) {
+  size_t num_nodes = tree->param.num_nodes;
+  if (mean_values->size() == num_nodes) {
+    return;
+  }
+  mean_values->resize(num_nodes);
+  FillNodeMeanValues(tree, 0, mean_values);
+}
+
 class CPUPredictor : public Predictor {
  protected:
   // init thread buffers
@@ -396,9 +422,10 @@ class CPUPredictor : public Predictor {
     }
   }
 
-  void PredictContribution(DMatrix* p_fmat, HostDeviceVector<float>* out_contribs,
-                           const gbm::GBTreeModel& model, uint32_t ntree_limit,
-                           std::vector<bst_float>* tree_weights,
+  void PredictContribution(DMatrix *p_fmat,
+                           HostDeviceVector<float> *out_contribs,
+                           const gbm::GBTreeModel &model, uint32_t ntree_limit,
+                           std::vector<bst_float> const *tree_weights,
                            bool approximate, int condition,
                            unsigned condition_feature) const override {
     const int nthread = omp_get_max_threads();
@@ -421,8 +448,9 @@ class CPUPredictor : public Predictor {
     // allocated one
     std::fill(contribs.begin(), contribs.end(), 0);
     // initialize tree node mean values
+    std::vector<std::vector<float>> mean_values(ntree_limit);
     common::ParallelFor(bst_omp_uint(ntree_limit), [&](bst_omp_uint i) {
-      model.trees[i]->FillNodeMeanValues();
+      FillNodeMeanValues(model.trees[i].get(), &(mean_values[i]));
     });
     const std::vector<bst_float>& base_margin = info.base_margin_.HostVector();
     // start collecting the contributions
@@ -443,19 +471,23 @@ class CPUPredictor : public Predictor {
         feats.Fill(page[i]);
         // calculate contributions
         for (unsigned j = 0; j < ntree_limit; ++j) {
+          auto *tree_mean_values = &mean_values.at(j);
           std::fill(this_tree_contribs.begin(), this_tree_contribs.end(), 0);
           if (model.tree_info[j] != gid) {
             continue;
           }
           if (!approximate) {
-            model.trees[j]->CalculateContributions(feats, &this_tree_contribs[0],
-                                                   condition, condition_feature);
+            model.trees[j]->CalculateContributions(
+                feats, tree_mean_values, &this_tree_contribs[0], condition,
+                condition_feature);
           } else {
-            model.trees[j]->CalculateContributionsApprox(feats, &this_tree_contribs[0]);
+            model.trees[j]->CalculateContributionsApprox(
+                feats, tree_mean_values, &this_tree_contribs[0]);
           }
-          for (size_t ci = 0 ; ci < ncolumns ; ++ci) {
-            p_contribs[ci] += this_tree_contribs[ci] *
-                              (tree_weights == nullptr ? 1 : (*tree_weights)[j]);
+          for (size_t ci = 0; ci < ncolumns; ++ci) {
+            p_contribs[ci] +=
+                this_tree_contribs[ci] *
+                (tree_weights == nullptr ? 1 : (*tree_weights)[j]);
           }
         }
         feats.Drop(page[i]);
@@ -470,10 +502,11 @@ class CPUPredictor : public Predictor {
     }
   }
 
-  void PredictInteractionContributions(DMatrix* p_fmat, HostDeviceVector<bst_float>* out_contribs,
-                                       const gbm::GBTreeModel& model, unsigned ntree_limit,
-                                       std::vector<bst_float>* tree_weights,
-                                       bool approximate) const override {
+  void PredictInteractionContributions(
+      DMatrix *p_fmat, HostDeviceVector<bst_float> *out_contribs,
+      const gbm::GBTreeModel &model, unsigned ntree_limit,
+      std::vector<bst_float> const *tree_weights,
+      bool approximate) const override {
     const MetaInfo& info = p_fmat->Info();
     const int ngroup = model.learner_model_param->num_output_group;
     size_t const ncolumns = model.learner_model_param->num_feature;
@@ -696,7 +696,7 @@ class GPUPredictor : public xgboost::Predictor {
   void PredictContribution(DMatrix* p_fmat,
                            HostDeviceVector<bst_float>* out_contribs,
                            const gbm::GBTreeModel& model, unsigned tree_end,
-                           std::vector<bst_float>*,
+                           std::vector<bst_float> const*,
                            bool approximate, int,
                            unsigned) const override {
     if (approximate) {
@@ -746,7 +746,7 @@ class GPUPredictor : public xgboost::Predictor {
                                        HostDeviceVector<bst_float>* out_contribs,
                                        const gbm::GBTreeModel& model,
                                        unsigned tree_end,
-                                       std::vector<bst_float>*,
+                                       std::vector<bst_float> const*,
                                        bool approximate) const override {
     if (approximate) {
       LOG(FATAL) << "[Internal error]: " << __func__
@@ -1128,36 +1128,14 @@ void RegTree::SaveModel(Json* p_out) const {
   out["default_left"] = std::move(default_left);
 }
 
-void RegTree::FillNodeMeanValues() {
-  size_t num_nodes = this->param.num_nodes;
-  if (this->node_mean_values_.size() == num_nodes) {
-    return;
-  }
-  this->node_mean_values_.resize(num_nodes);
-  this->FillNodeMeanValue(0);
-}
-
-bst_float RegTree::FillNodeMeanValue(int nid) {
-  bst_float result;
-  auto& node = (*this)[nid];
-  if (node.IsLeaf()) {
-    result = node.LeafValue();
-  } else {
-    result = this->FillNodeMeanValue(node.LeftChild()) * this->Stat(node.LeftChild()).sum_hess;
-    result += this->FillNodeMeanValue(node.RightChild()) * this->Stat(node.RightChild()).sum_hess;
-    result /= this->Stat(nid).sum_hess;
-  }
-  this->node_mean_values_[nid] = result;
-  return result;
-}
-
 void RegTree::CalculateContributionsApprox(const RegTree::FVec &feat,
+                                           std::vector<float>* mean_values,
                                            bst_float *out_contribs) const {
-  CHECK_GT(this->node_mean_values_.size(), 0U);
+  CHECK_GT(mean_values->size(), 0U);
   // this follows the idea of http://blog.datadive.net/interpreting-random-forests/
   unsigned split_index = 0;
   // update bias value
-  bst_float node_value = this->node_mean_values_[0];
+  bst_float node_value = (*mean_values)[0];
   out_contribs[feat.Size()] += node_value;
   if ((*this)[0].IsLeaf()) {
     // nothing to do anymore
@@ -1172,7 +1150,7 @@ void RegTree::CalculateContributionsApprox(const RegTree::FVec &feat,
     nid = predictor::GetNextNode<true, true>((*this)[nid], nid,
                                              feat.GetFvalue(split_index),
                                              feat.IsMissing(split_index), cats);
-    bst_float new_value = this->node_mean_values_[nid];
+    bst_float new_value = (*mean_values)[nid];
     // update feature weight
     out_contribs[split_index] += new_value - node_value;
     node_value = new_value;
@@ -1352,12 +1330,13 @@ void RegTree::TreeShap(const RegTree::FVec &feat, bst_float *phi,
 }
 
 void RegTree::CalculateContributions(const RegTree::FVec &feat,
+                                     std::vector<float>* mean_values,
                                      bst_float *out_contribs,
                                      int condition,
                                      unsigned condition_feature) const {
   // find the expected value of the tree's predictions
   if (condition == 0) {
-    bst_float node_value = this->node_mean_values_[0];
+    bst_float node_value = (*mean_values)[0];
     out_contribs[feat.Size()] += node_value;
   }