Fix race condition in CPU shap. (#7050)

This commit is contained in:
Jiaming Yuan
2021-06-21 10:03:15 +08:00
committed by GitHub
parent 29f8fd6fee
commit bbfffb444d
5 changed files with 68 additions and 64 deletions

View File

@@ -206,22 +206,18 @@ class Predictor {
* \param condition_feature Feature to condition on (i.e. fix) during calculations.
*/
virtual void PredictContribution(DMatrix* dmat,
HostDeviceVector<bst_float>* out_contribs,
const gbm::GBTreeModel& model,
unsigned tree_end = 0,
std::vector<bst_float>* tree_weights = nullptr,
bool approximate = false,
int condition = 0,
unsigned condition_feature = 0) const = 0;
virtual void PredictInteractionContributions(DMatrix* dmat,
HostDeviceVector<bst_float>* out_contribs,
const gbm::GBTreeModel& model,
unsigned tree_end = 0,
std::vector<bst_float>* tree_weights = nullptr,
bool approximate = false) const = 0;
virtual void
PredictContribution(DMatrix *dmat, HostDeviceVector<bst_float> *out_contribs,
const gbm::GBTreeModel &model, unsigned tree_end = 0,
std::vector<bst_float> const *tree_weights = nullptr,
bool approximate = false, int condition = 0,
unsigned condition_feature = 0) const = 0;
virtual void PredictInteractionContributions(
DMatrix *dmat, HostDeviceVector<bst_float> *out_contribs,
const gbm::GBTreeModel &model, unsigned tree_end = 0,
std::vector<bst_float> const *tree_weights = nullptr,
bool approximate = false) const = 0;
/**
* \brief Creates a new Predictor*.

View File

@@ -550,6 +550,7 @@ class RegTree : public Model {
* \param condition_feature the index of the feature to fix
*/
void CalculateContributions(const RegTree::FVec& feat,
std::vector<float>* mean_values,
bst_float* out_contribs, int condition = 0,
unsigned condition_feature = 0) const;
/*!
@@ -578,6 +579,7 @@ class RegTree : public Model {
* \param out_contribs output vector to hold the contributions
*/
void CalculateContributionsApprox(const RegTree::FVec& feat,
std::vector<float>* mean_values,
bst_float* out_contribs) const;
/*!
* \brief dump the model in the requested format as a text string
@@ -589,10 +591,6 @@ class RegTree : public Model {
std::string DumpModel(const FeatureMap& fmap,
bool with_stats,
std::string format) const;
/*!
* \brief calculate the mean value for each node, required for feature contributions
*/
void FillNodeMeanValues();
/*!
* \brief Get split type for a node.
* \param nidx Index of node.
@@ -639,7 +637,6 @@ class RegTree : public Model {
std::vector<int> deleted_nodes_;
// stats of nodes
std::vector<RTreeNodeStat> stats_;
std::vector<bst_float> node_mean_values_;
std::vector<FeatureType> split_types_;
// Categories for each internal node.
@@ -680,7 +677,6 @@ class RegTree : public Model {
nodes_[nid].MarkDelete();
++param.num_deleted;
}
bst_float FillNodeMeanValue(int nid);
};
inline void RegTree::FVec::Init(size_t size) {