Update GPUTreeShap (#6163)
* Reduce shap test duration
* Test interoperability with shap package
* Add feature interactions
* Update GPUTreeShap
This commit is contained in:
@@ -352,7 +352,7 @@ class CPUPredictor : public Predictor {
|
||||
}
|
||||
}
|
||||
|
||||
void PredictContribution(DMatrix* p_fmat, std::vector<bst_float>* out_contribs,
|
||||
void PredictContribution(DMatrix* p_fmat, HostDeviceVector<float>* out_contribs,
|
||||
const gbm::GBTreeModel& model, uint32_t ntree_limit,
|
||||
std::vector<bst_float>* tree_weights,
|
||||
bool approximate, int condition,
|
||||
@@ -370,7 +370,7 @@ class CPUPredictor : public Predictor {
|
||||
size_t const ncolumns = model.learner_model_param->num_feature + 1;
|
||||
CHECK_NE(ncolumns, 0);
|
||||
// allocate space for (number of features + bias) times the number of rows
|
||||
std::vector<bst_float>& contribs = *out_contribs;
|
||||
std::vector<bst_float>& contribs = out_contribs->HostVector();
|
||||
contribs.resize(info.num_row_ * ncolumns * model.learner_model_param->num_output_group);
|
||||
// make sure contributions is zeroed, we could be reusing a previously
|
||||
// allocated one
|
||||
@@ -423,7 +423,7 @@ class CPUPredictor : public Predictor {
|
||||
}
|
||||
}
|
||||
|
||||
void PredictInteractionContributions(DMatrix* p_fmat, std::vector<bst_float>* out_contribs,
|
||||
void PredictInteractionContributions(DMatrix* p_fmat, HostDeviceVector<bst_float>* out_contribs,
|
||||
const gbm::GBTreeModel& model, unsigned ntree_limit,
|
||||
std::vector<bst_float>* tree_weights,
|
||||
bool approximate) override {
|
||||
@@ -435,21 +435,24 @@ class CPUPredictor : public Predictor {
|
||||
const unsigned crow_chunk = ngroup * (ncolumns + 1);
|
||||
|
||||
// allocate space for (number of features^2) times the number of rows and tmp off/on contribs
|
||||
std::vector<bst_float>& contribs = *out_contribs;
|
||||
std::vector<bst_float>& contribs = out_contribs->HostVector();
|
||||
contribs.resize(info.num_row_ * ngroup * (ncolumns + 1) * (ncolumns + 1));
|
||||
std::vector<bst_float> contribs_off(info.num_row_ * ngroup * (ncolumns + 1));
|
||||
std::vector<bst_float> contribs_on(info.num_row_ * ngroup * (ncolumns + 1));
|
||||
std::vector<bst_float> contribs_diag(info.num_row_ * ngroup * (ncolumns + 1));
|
||||
HostDeviceVector<bst_float> contribs_off_hdv(info.num_row_ * ngroup * (ncolumns + 1));
|
||||
auto &contribs_off = contribs_off_hdv.HostVector();
|
||||
HostDeviceVector<bst_float> contribs_on_hdv(info.num_row_ * ngroup * (ncolumns + 1));
|
||||
auto &contribs_on = contribs_on_hdv.HostVector();
|
||||
HostDeviceVector<bst_float> contribs_diag_hdv(info.num_row_ * ngroup * (ncolumns + 1));
|
||||
auto &contribs_diag = contribs_diag_hdv.HostVector();
|
||||
|
||||
// Compute the difference in effects when conditioning on each of the features on and off
|
||||
// see: Axiomatic characterizations of probabilistic and
|
||||
// cardinal-probabilistic interaction indices
|
||||
PredictContribution(p_fmat, &contribs_diag, model, ntree_limit,
|
||||
PredictContribution(p_fmat, &contribs_diag_hdv, model, ntree_limit,
|
||||
tree_weights, approximate, 0, 0);
|
||||
for (size_t i = 0; i < ncolumns + 1; ++i) {
|
||||
PredictContribution(p_fmat, &contribs_off, model, ntree_limit,
|
||||
PredictContribution(p_fmat, &contribs_off_hdv, model, ntree_limit,
|
||||
tree_weights, approximate, -1, i);
|
||||
PredictContribution(p_fmat, &contribs_on, model, ntree_limit,
|
||||
PredictContribution(p_fmat, &contribs_on_hdv, model, ntree_limit,
|
||||
tree_weights, approximate, 1, i);
|
||||
|
||||
for (size_t j = 0; j < info.num_row_; ++j) {
|
||||
|
||||
@@ -553,7 +553,7 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
}
|
||||
|
||||
void PredictContribution(DMatrix* p_fmat,
|
||||
std::vector<bst_float>* out_contribs,
|
||||
HostDeviceVector<bst_float>* out_contribs,
|
||||
const gbm::GBTreeModel& model, unsigned ntree_limit,
|
||||
std::vector<bst_float>* tree_weights,
|
||||
bool approximate, int condition,
|
||||
@@ -564,6 +564,7 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
}
|
||||
|
||||
dh::safe_cuda(cudaSetDevice(generic_param_->gpu_id));
|
||||
out_contribs->SetDevice(generic_param_->gpu_id);
|
||||
uint32_t real_ntree_limit =
|
||||
ntree_limit * model.learner_model_param->num_output_group;
|
||||
if (real_ntree_limit == 0 || real_ntree_limit > model.trees.size()) {
|
||||
@@ -573,22 +574,21 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
const int ngroup = model.learner_model_param->num_output_group;
|
||||
CHECK_NE(ngroup, 0);
|
||||
// allocate space for (number of features + bias) times the number of rows
|
||||
std::vector<bst_float>& contribs = *out_contribs;
|
||||
size_t contributions_columns =
|
||||
model.learner_model_param->num_feature + 1; // +1 for bias
|
||||
contribs.resize(p_fmat->Info().num_row_ * contributions_columns *
|
||||
out_contribs->Resize(p_fmat->Info().num_row_ * contributions_columns *
|
||||
model.learner_model_param->num_output_group);
|
||||
dh::TemporaryArray<float> phis(contribs.size(), 0.0);
|
||||
out_contribs->Fill(0.0f);
|
||||
auto phis = out_contribs->DeviceSpan();
|
||||
p_fmat->Info().base_margin_.SetDevice(generic_param_->gpu_id);
|
||||
const auto margin = p_fmat->Info().base_margin_.ConstDeviceSpan();
|
||||
float base_score = model.learner_model_param->base_score;
|
||||
auto d_phis = phis.data().get();
|
||||
// Add the base margin term to last column
|
||||
dh::LaunchN(
|
||||
generic_param_->gpu_id,
|
||||
p_fmat->Info().num_row_ * model.learner_model_param->num_output_group,
|
||||
[=] __device__(size_t idx) {
|
||||
d_phis[(idx + 1) * contributions_columns - 1] =
|
||||
phis[(idx + 1) * contributions_columns - 1] =
|
||||
margin.empty() ? base_score : margin[idx];
|
||||
});
|
||||
|
||||
@@ -602,11 +602,67 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
model.learner_model_param->num_feature);
|
||||
gpu_treeshap::GPUTreeShap(
|
||||
X, device_paths.begin(), device_paths.end(), ngroup,
|
||||
phis.data().get() + batch.base_rowid * contributions_columns);
|
||||
phis.data() + batch.base_rowid * contributions_columns, phis.size());
|
||||
}
|
||||
}
|
||||
|
||||
void PredictInteractionContributions(DMatrix* p_fmat,
|
||||
HostDeviceVector<bst_float>* out_contribs,
|
||||
const gbm::GBTreeModel& model,
|
||||
unsigned ntree_limit,
|
||||
std::vector<bst_float>* tree_weights,
|
||||
bool approximate) override {
|
||||
if (approximate) {
|
||||
LOG(FATAL) << "[Internal error]: " << __func__
|
||||
<< " approximate is not implemented in GPU Predictor.";
|
||||
}
|
||||
|
||||
dh::safe_cuda(cudaSetDevice(generic_param_->gpu_id));
|
||||
out_contribs->SetDevice(generic_param_->gpu_id);
|
||||
uint32_t real_ntree_limit =
|
||||
ntree_limit * model.learner_model_param->num_output_group;
|
||||
if (real_ntree_limit == 0 || real_ntree_limit > model.trees.size()) {
|
||||
real_ntree_limit = static_cast<uint32_t>(model.trees.size());
|
||||
}
|
||||
|
||||
const int ngroup = model.learner_model_param->num_output_group;
|
||||
CHECK_NE(ngroup, 0);
|
||||
// allocate space for (number of features + bias)^2 times the number of rows, per output group
|
||||
size_t contributions_columns =
|
||||
model.learner_model_param->num_feature + 1; // +1 for bias
|
||||
out_contribs->Resize(p_fmat->Info().num_row_ * contributions_columns *
|
||||
contributions_columns *
|
||||
model.learner_model_param->num_output_group);
|
||||
out_contribs->Fill(0.0f);
|
||||
auto phis = out_contribs->DeviceSpan();
|
||||
p_fmat->Info().base_margin_.SetDevice(generic_param_->gpu_id);
|
||||
const auto margin = p_fmat->Info().base_margin_.ConstDeviceSpan();
|
||||
float base_score = model.learner_model_param->base_score;
|
||||
// Add the base margin term to last column
|
||||
size_t n_features = model.learner_model_param->num_feature;
|
||||
dh::LaunchN(
|
||||
generic_param_->gpu_id,
|
||||
p_fmat->Info().num_row_ * model.learner_model_param->num_output_group,
|
||||
[=] __device__(size_t idx) {
|
||||
size_t group = idx % ngroup;
|
||||
size_t row_idx = idx / ngroup;
|
||||
phis[gpu_treeshap::IndexPhiInteractions(
|
||||
row_idx, ngroup, group, n_features, n_features, n_features)] =
|
||||
margin.empty() ? base_score : margin[idx];
|
||||
});
|
||||
|
||||
dh::device_vector<gpu_treeshap::PathElement> device_paths;
|
||||
ExtractPaths(&device_paths, model, real_ntree_limit,
|
||||
generic_param_->gpu_id);
|
||||
for (auto& batch : p_fmat->GetBatches<SparsePage>()) {
|
||||
batch.data.SetDevice(generic_param_->gpu_id);
|
||||
batch.offset.SetDevice(generic_param_->gpu_id);
|
||||
SparsePageView X(batch.data.DeviceSpan(), batch.offset.DeviceSpan(),
|
||||
model.learner_model_param->num_feature);
|
||||
gpu_treeshap::GPUTreeShapInteractions(
|
||||
X, device_paths.begin(), device_paths.end(), ngroup,
|
||||
phis.data() + batch.base_rowid * contributions_columns, phis.size());
|
||||
}
|
||||
dh::safe_cuda(cudaMemcpy(contribs.data(), phis.data().get(),
|
||||
sizeof(float) * phis.size(),
|
||||
cudaMemcpyDefault));
|
||||
}
|
||||
|
||||
protected:
|
||||
@@ -640,16 +696,6 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
<< " is not implemented in GPU Predictor.";
|
||||
}
|
||||
|
||||
void PredictInteractionContributions(DMatrix* p_fmat,
|
||||
std::vector<bst_float>* out_contribs,
|
||||
const gbm::GBTreeModel& model,
|
||||
unsigned ntree_limit,
|
||||
std::vector<bst_float>* tree_weights,
|
||||
bool approximate) override {
|
||||
LOG(FATAL) << "[Internal error]: " << __func__
|
||||
<< " is not implemented in GPU Predictor.";
|
||||
}
|
||||
|
||||
void Configure(const std::vector<std::pair<std::string, std::string>>& cfg) override {
|
||||
Predictor::Configure(cfg);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user