Update GPUTreeshap (#6163)

* Reduce shap test duration

* Test interoperability with shap package

* Add feature interactions

* Update GPUTreeShap
This commit is contained in:
Rory Mitchell 2020-09-28 09:43:47 +13:00 committed by GitHub
parent 434a3f35a3
commit dda9e1e487
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 176 additions and 87 deletions

@ -1 +1 @@
Subproject commit 3fd8bb118ffc8516f86417253aacdae1531977f4 Subproject commit 5f33132d75482338f78cfba562791d8445e157f6

View File

@ -147,13 +147,13 @@ class GradientBooster : public Model, public Configurable {
* \param condition_feature feature to condition on (i.e. fix) during calculations * \param condition_feature feature to condition on (i.e. fix) during calculations
*/ */
virtual void PredictContribution(DMatrix* dmat, virtual void PredictContribution(DMatrix* dmat,
std::vector<bst_float>* out_contribs, HostDeviceVector<bst_float>* out_contribs,
unsigned ntree_limit = 0, unsigned ntree_limit = 0,
bool approximate = false, int condition = 0, bool approximate = false, int condition = 0,
unsigned condition_feature = 0) = 0; unsigned condition_feature = 0) = 0;
virtual void PredictInteractionContributions(DMatrix* dmat, virtual void PredictInteractionContributions(DMatrix* dmat,
std::vector<bst_float>* out_contribs, HostDeviceVector<bst_float>* out_contribs,
unsigned ntree_limit, bool approximate) = 0; unsigned ntree_limit, bool approximate) = 0;
/*! /*!

View File

@ -201,7 +201,7 @@ class Predictor {
*/ */
virtual void PredictContribution(DMatrix* dmat, virtual void PredictContribution(DMatrix* dmat,
std::vector<bst_float>* out_contribs, HostDeviceVector<bst_float>* out_contribs,
const gbm::GBTreeModel& model, const gbm::GBTreeModel& model,
unsigned ntree_limit = 0, unsigned ntree_limit = 0,
std::vector<bst_float>* tree_weights = nullptr, std::vector<bst_float>* tree_weights = nullptr,
@ -210,7 +210,7 @@ class Predictor {
unsigned condition_feature = 0) = 0; unsigned condition_feature = 0) = 0;
virtual void PredictInteractionContributions(DMatrix* dmat, virtual void PredictInteractionContributions(DMatrix* dmat,
std::vector<bst_float>* out_contribs, HostDeviceVector<bst_float>* out_contribs,
const gbm::GBTreeModel& model, const gbm::GBTreeModel& model,
unsigned ntree_limit = 0, unsigned ntree_limit = 0,
std::vector<bst_float>* tree_weights = nullptr, std::vector<bst_float>* tree_weights = nullptr,

View File

@ -155,7 +155,7 @@ class GBLinear : public GradientBooster {
} }
void PredictContribution(DMatrix* p_fmat, void PredictContribution(DMatrix* p_fmat,
std::vector<bst_float>* out_contribs, HostDeviceVector<bst_float>* out_contribs,
unsigned ntree_limit, bool approximate, int condition = 0, unsigned ntree_limit, bool approximate, int condition = 0,
unsigned condition_feature = 0) override { unsigned condition_feature = 0) override {
model_.LazyInitModel(); model_.LazyInitModel();
@ -165,7 +165,7 @@ class GBLinear : public GradientBooster {
const int ngroup = model_.learner_model_param->num_output_group; const int ngroup = model_.learner_model_param->num_output_group;
const size_t ncolumns = model_.learner_model_param->num_feature + 1; const size_t ncolumns = model_.learner_model_param->num_feature + 1;
// allocate space for (#features + bias) times #groups times #rows // allocate space for (#features + bias) times #groups times #rows
std::vector<bst_float>& contribs = *out_contribs; std::vector<bst_float>& contribs = out_contribs->HostVector();
contribs.resize(p_fmat->Info().num_row_ * ncolumns * ngroup); contribs.resize(p_fmat->Info().num_row_ * ncolumns * ngroup);
// make sure contributions is zeroed, we could be reusing a previously allocated one // make sure contributions is zeroed, we could be reusing a previously allocated one
std::fill(contribs.begin(), contribs.end(), 0); std::fill(contribs.begin(), contribs.end(), 0);
@ -195,9 +195,9 @@ class GBLinear : public GradientBooster {
} }
void PredictInteractionContributions(DMatrix* p_fmat, void PredictInteractionContributions(DMatrix* p_fmat,
std::vector<bst_float>* out_contribs, HostDeviceVector<bst_float>* out_contribs,
unsigned ntree_limit, bool approximate) override { unsigned ntree_limit, bool approximate) override {
std::vector<bst_float>& contribs = *out_contribs; std::vector<bst_float>& contribs = out_contribs->HostVector();
// linear models have no interaction effects // linear models have no interaction effects
const size_t nelements = model_.learner_model_param->num_feature * const size_t nelements = model_.learner_model_param->num_feature *

View File

@ -600,7 +600,7 @@ class Dart : public GBTree {
} }
void PredictContribution(DMatrix* p_fmat, void PredictContribution(DMatrix* p_fmat,
std::vector<bst_float>* out_contribs, HostDeviceVector<bst_float>* out_contribs,
unsigned ntree_limit, bool approximate, int condition, unsigned ntree_limit, bool approximate, int condition,
unsigned condition_feature) override { unsigned condition_feature) override {
CHECK(configured_); CHECK(configured_);
@ -609,7 +609,7 @@ class Dart : public GBTree {
} }
void PredictInteractionContributions(DMatrix* p_fmat, void PredictInteractionContributions(DMatrix* p_fmat,
std::vector<bst_float>* out_contribs, HostDeviceVector<bst_float>* out_contribs,
unsigned ntree_limit, bool approximate) override { unsigned ntree_limit, bool approximate) override {
CHECK(configured_); CHECK(configured_);
cpu_predictor_->PredictInteractionContributions(p_fmat, out_contribs, model_, cpu_predictor_->PredictInteractionContributions(p_fmat, out_contribs, model_,

View File

@ -237,7 +237,7 @@ class GBTree : public GradientBooster {
} }
void PredictContribution(DMatrix* p_fmat, void PredictContribution(DMatrix* p_fmat,
std::vector<bst_float>* out_contribs, HostDeviceVector<bst_float>* out_contribs,
unsigned ntree_limit, bool approximate, unsigned ntree_limit, bool approximate,
int condition, unsigned condition_feature) override { int condition, unsigned condition_feature) override {
CHECK(configured_); CHECK(configured_);
@ -246,10 +246,10 @@ class GBTree : public GradientBooster {
} }
void PredictInteractionContributions(DMatrix* p_fmat, void PredictInteractionContributions(DMatrix* p_fmat,
std::vector<bst_float>* out_contribs, HostDeviceVector<bst_float>* out_contribs,
unsigned ntree_limit, bool approximate) override { unsigned ntree_limit, bool approximate) override {
CHECK(configured_); CHECK(configured_);
cpu_predictor_->PredictInteractionContributions(p_fmat, out_contribs, model_, this->GetPredictor()->PredictInteractionContributions(p_fmat, out_contribs, model_,
ntree_limit, nullptr, approximate); ntree_limit, nullptr, approximate);
} }

View File

@ -1068,9 +1068,9 @@ class LearnerImpl : public LearnerIO {
this->Configure(); this->Configure();
CHECK_LE(multiple_predictions, 1) << "Perform one kind of prediction at a time."; CHECK_LE(multiple_predictions, 1) << "Perform one kind of prediction at a time.";
if (pred_contribs) { if (pred_contribs) {
gbm_->PredictContribution(data.get(), &out_preds->HostVector(), ntree_limit, approx_contribs); gbm_->PredictContribution(data.get(), out_preds, ntree_limit, approx_contribs);
} else if (pred_interactions) { } else if (pred_interactions) {
gbm_->PredictInteractionContributions(data.get(), &out_preds->HostVector(), ntree_limit, gbm_->PredictInteractionContributions(data.get(), out_preds, ntree_limit,
approx_contribs); approx_contribs);
} else if (pred_leaf) { } else if (pred_leaf) {
gbm_->PredictLeaf(data.get(), &out_preds->HostVector(), ntree_limit); gbm_->PredictLeaf(data.get(), &out_preds->HostVector(), ntree_limit);

View File

@ -352,7 +352,7 @@ class CPUPredictor : public Predictor {
} }
} }
void PredictContribution(DMatrix* p_fmat, std::vector<bst_float>* out_contribs, void PredictContribution(DMatrix* p_fmat, HostDeviceVector<float>* out_contribs,
const gbm::GBTreeModel& model, uint32_t ntree_limit, const gbm::GBTreeModel& model, uint32_t ntree_limit,
std::vector<bst_float>* tree_weights, std::vector<bst_float>* tree_weights,
bool approximate, int condition, bool approximate, int condition,
@ -370,7 +370,7 @@ class CPUPredictor : public Predictor {
size_t const ncolumns = model.learner_model_param->num_feature + 1; size_t const ncolumns = model.learner_model_param->num_feature + 1;
CHECK_NE(ncolumns, 0); CHECK_NE(ncolumns, 0);
// allocate space for (number of features + bias) times the number of rows // allocate space for (number of features + bias) times the number of rows
std::vector<bst_float>& contribs = *out_contribs; std::vector<bst_float>& contribs = out_contribs->HostVector();
contribs.resize(info.num_row_ * ncolumns * model.learner_model_param->num_output_group); contribs.resize(info.num_row_ * ncolumns * model.learner_model_param->num_output_group);
// make sure contributions is zeroed, we could be reusing a previously // make sure contributions is zeroed, we could be reusing a previously
// allocated one // allocated one
@ -423,7 +423,7 @@ class CPUPredictor : public Predictor {
} }
} }
void PredictInteractionContributions(DMatrix* p_fmat, std::vector<bst_float>* out_contribs, void PredictInteractionContributions(DMatrix* p_fmat, HostDeviceVector<bst_float>* out_contribs,
const gbm::GBTreeModel& model, unsigned ntree_limit, const gbm::GBTreeModel& model, unsigned ntree_limit,
std::vector<bst_float>* tree_weights, std::vector<bst_float>* tree_weights,
bool approximate) override { bool approximate) override {
@ -435,21 +435,24 @@ class CPUPredictor : public Predictor {
const unsigned crow_chunk = ngroup * (ncolumns + 1); const unsigned crow_chunk = ngroup * (ncolumns + 1);
// allocate space for (number of features^2) times the number of rows and tmp off/on contribs // allocate space for (number of features^2) times the number of rows and tmp off/on contribs
std::vector<bst_float>& contribs = *out_contribs; std::vector<bst_float>& contribs = out_contribs->HostVector();
contribs.resize(info.num_row_ * ngroup * (ncolumns + 1) * (ncolumns + 1)); contribs.resize(info.num_row_ * ngroup * (ncolumns + 1) * (ncolumns + 1));
std::vector<bst_float> contribs_off(info.num_row_ * ngroup * (ncolumns + 1)); HostDeviceVector<bst_float> contribs_off_hdv(info.num_row_ * ngroup * (ncolumns + 1));
std::vector<bst_float> contribs_on(info.num_row_ * ngroup * (ncolumns + 1)); auto &contribs_off = contribs_off_hdv.HostVector();
std::vector<bst_float> contribs_diag(info.num_row_ * ngroup * (ncolumns + 1)); HostDeviceVector<bst_float> contribs_on_hdv(info.num_row_ * ngroup * (ncolumns + 1));
auto &contribs_on = contribs_on_hdv.HostVector();
HostDeviceVector<bst_float> contribs_diag_hdv(info.num_row_ * ngroup * (ncolumns + 1));
auto &contribs_diag = contribs_diag_hdv.HostVector();
// Compute the difference in effects when conditioning on each of the features on and off // Compute the difference in effects when conditioning on each of the features on and off
// see: Axiomatic characterizations of probabilistic and // see: Axiomatic characterizations of probabilistic and
// cardinal-probabilistic interaction indices // cardinal-probabilistic interaction indices
PredictContribution(p_fmat, &contribs_diag, model, ntree_limit, PredictContribution(p_fmat, &contribs_diag_hdv, model, ntree_limit,
tree_weights, approximate, 0, 0); tree_weights, approximate, 0, 0);
for (size_t i = 0; i < ncolumns + 1; ++i) { for (size_t i = 0; i < ncolumns + 1; ++i) {
PredictContribution(p_fmat, &contribs_off, model, ntree_limit, PredictContribution(p_fmat, &contribs_off_hdv, model, ntree_limit,
tree_weights, approximate, -1, i); tree_weights, approximate, -1, i);
PredictContribution(p_fmat, &contribs_on, model, ntree_limit, PredictContribution(p_fmat, &contribs_on_hdv, model, ntree_limit,
tree_weights, approximate, 1, i); tree_weights, approximate, 1, i);
for (size_t j = 0; j < info.num_row_; ++j) { for (size_t j = 0; j < info.num_row_; ++j) {

View File

@ -553,7 +553,7 @@ class GPUPredictor : public xgboost::Predictor {
} }
void PredictContribution(DMatrix* p_fmat, void PredictContribution(DMatrix* p_fmat,
std::vector<bst_float>* out_contribs, HostDeviceVector<bst_float>* out_contribs,
const gbm::GBTreeModel& model, unsigned ntree_limit, const gbm::GBTreeModel& model, unsigned ntree_limit,
std::vector<bst_float>* tree_weights, std::vector<bst_float>* tree_weights,
bool approximate, int condition, bool approximate, int condition,
@ -564,6 +564,7 @@ class GPUPredictor : public xgboost::Predictor {
} }
dh::safe_cuda(cudaSetDevice(generic_param_->gpu_id)); dh::safe_cuda(cudaSetDevice(generic_param_->gpu_id));
out_contribs->SetDevice(generic_param_->gpu_id);
uint32_t real_ntree_limit = uint32_t real_ntree_limit =
ntree_limit * model.learner_model_param->num_output_group; ntree_limit * model.learner_model_param->num_output_group;
if (real_ntree_limit == 0 || real_ntree_limit > model.trees.size()) { if (real_ntree_limit == 0 || real_ntree_limit > model.trees.size()) {
@ -573,22 +574,21 @@ class GPUPredictor : public xgboost::Predictor {
const int ngroup = model.learner_model_param->num_output_group; const int ngroup = model.learner_model_param->num_output_group;
CHECK_NE(ngroup, 0); CHECK_NE(ngroup, 0);
// allocate space for (number of features + bias) times the number of rows // allocate space for (number of features + bias) times the number of rows
std::vector<bst_float>& contribs = *out_contribs;
size_t contributions_columns = size_t contributions_columns =
model.learner_model_param->num_feature + 1; // +1 for bias model.learner_model_param->num_feature + 1; // +1 for bias
contribs.resize(p_fmat->Info().num_row_ * contributions_columns * out_contribs->Resize(p_fmat->Info().num_row_ * contributions_columns *
model.learner_model_param->num_output_group); model.learner_model_param->num_output_group);
dh::TemporaryArray<float> phis(contribs.size(), 0.0); out_contribs->Fill(0.0f);
auto phis = out_contribs->DeviceSpan();
p_fmat->Info().base_margin_.SetDevice(generic_param_->gpu_id); p_fmat->Info().base_margin_.SetDevice(generic_param_->gpu_id);
const auto margin = p_fmat->Info().base_margin_.ConstDeviceSpan(); const auto margin = p_fmat->Info().base_margin_.ConstDeviceSpan();
float base_score = model.learner_model_param->base_score; float base_score = model.learner_model_param->base_score;
auto d_phis = phis.data().get();
// Add the base margin term to last column // Add the base margin term to last column
dh::LaunchN( dh::LaunchN(
generic_param_->gpu_id, generic_param_->gpu_id,
p_fmat->Info().num_row_ * model.learner_model_param->num_output_group, p_fmat->Info().num_row_ * model.learner_model_param->num_output_group,
[=] __device__(size_t idx) { [=] __device__(size_t idx) {
d_phis[(idx + 1) * contributions_columns - 1] = phis[(idx + 1) * contributions_columns - 1] =
margin.empty() ? base_score : margin[idx]; margin.empty() ? base_score : margin[idx];
}); });
@ -602,11 +602,67 @@ class GPUPredictor : public xgboost::Predictor {
model.learner_model_param->num_feature); model.learner_model_param->num_feature);
gpu_treeshap::GPUTreeShap( gpu_treeshap::GPUTreeShap(
X, device_paths.begin(), device_paths.end(), ngroup, X, device_paths.begin(), device_paths.end(), ngroup,
phis.data().get() + batch.base_rowid * contributions_columns); phis.data() + batch.base_rowid * contributions_columns, phis.size());
}
}
void PredictInteractionContributions(DMatrix* p_fmat,
HostDeviceVector<bst_float>* out_contribs,
const gbm::GBTreeModel& model,
unsigned ntree_limit,
std::vector<bst_float>* tree_weights,
bool approximate) override {
if (approximate) {
LOG(FATAL) << "[Internal error]: " << __func__
<< " approximate is not implemented in GPU Predictor.";
}
dh::safe_cuda(cudaSetDevice(generic_param_->gpu_id));
out_contribs->SetDevice(generic_param_->gpu_id);
uint32_t real_ntree_limit =
ntree_limit * model.learner_model_param->num_output_group;
if (real_ntree_limit == 0 || real_ntree_limit > model.trees.size()) {
real_ntree_limit = static_cast<uint32_t>(model.trees.size());
}
const int ngroup = model.learner_model_param->num_output_group;
CHECK_NE(ngroup, 0);
// allocate space for (number of features + bias) times the number of rows
size_t contributions_columns =
model.learner_model_param->num_feature + 1; // +1 for bias
out_contribs->Resize(p_fmat->Info().num_row_ * contributions_columns *
contributions_columns *
model.learner_model_param->num_output_group);
out_contribs->Fill(0.0f);
auto phis = out_contribs->DeviceSpan();
p_fmat->Info().base_margin_.SetDevice(generic_param_->gpu_id);
const auto margin = p_fmat->Info().base_margin_.ConstDeviceSpan();
float base_score = model.learner_model_param->base_score;
// Add the base margin term to last column
size_t n_features = model.learner_model_param->num_feature;
dh::LaunchN(
generic_param_->gpu_id,
p_fmat->Info().num_row_ * model.learner_model_param->num_output_group,
[=] __device__(size_t idx) {
size_t group = idx % ngroup;
size_t row_idx = idx / ngroup;
phis[gpu_treeshap::IndexPhiInteractions(
row_idx, ngroup, group, n_features, n_features, n_features)] =
margin.empty() ? base_score : margin[idx];
});
dh::device_vector<gpu_treeshap::PathElement> device_paths;
ExtractPaths(&device_paths, model, real_ntree_limit,
generic_param_->gpu_id);
for (auto& batch : p_fmat->GetBatches<SparsePage>()) {
batch.data.SetDevice(generic_param_->gpu_id);
batch.offset.SetDevice(generic_param_->gpu_id);
SparsePageView X(batch.data.DeviceSpan(), batch.offset.DeviceSpan(),
model.learner_model_param->num_feature);
gpu_treeshap::GPUTreeShapInteractions(
X, device_paths.begin(), device_paths.end(), ngroup,
phis.data() + batch.base_rowid * contributions_columns, phis.size());
} }
dh::safe_cuda(cudaMemcpy(contribs.data(), phis.data().get(),
sizeof(float) * phis.size(),
cudaMemcpyDefault));
} }
protected: protected:
@ -640,16 +696,6 @@ class GPUPredictor : public xgboost::Predictor {
<< " is not implemented in GPU Predictor."; << " is not implemented in GPU Predictor.";
} }
void PredictInteractionContributions(DMatrix* p_fmat,
std::vector<bst_float>* out_contribs,
const gbm::GBTreeModel& model,
unsigned ntree_limit,
std::vector<bst_float>* tree_weights,
bool approximate) override {
LOG(FATAL) << "[Internal error]: " << __func__
<< " is not implemented in GPU Predictor.";
}
void Configure(const std::vector<std::pair<std::string, std::string>>& cfg) override { void Configure(const std::vector<std::pair<std::string, std::string>>& cfg) override {
Predictor::Configure(cfg); Predictor::Configure(cfg);
} }

View File

@ -29,6 +29,7 @@ dependencies:
- boto3 - boto3
- awscli - awscli
- pip: - pip:
- shap
- guzzle_sphinx_theme - guzzle_sphinx_theme
- datatable - datatable
- modin[all] - modin[all]

View File

@ -53,12 +53,14 @@ TEST(CpuPredictor, Basic) {
} }
// Test predict contribution // Test predict contribution
std::vector<float> out_contribution; HostDeviceVector<float> out_contribution_hdv;
cpu_predictor->PredictContribution(dmat.get(), &out_contribution, model); auto& out_contribution = out_contribution_hdv.HostVector();
cpu_predictor->PredictContribution(dmat.get(), &out_contribution_hdv, model);
ASSERT_EQ(out_contribution.size(), kRows * (kCols + 1)); ASSERT_EQ(out_contribution.size(), kRows * (kCols + 1));
for (size_t i = 0; i < out_contribution.size(); ++i) { for (size_t i = 0; i < out_contribution.size(); ++i) {
auto const& contri = out_contribution[i]; auto const& contri = out_contribution[i];
// shift 1 for bias, as test tree is a decision dump, only global bias is filled with LeafValue(). // shift 1 for bias, as test tree is a decision dump, only global bias is
// filled with LeafValue().
if ((i + 1) % (kCols + 1) == 0) { if ((i + 1) % (kCols + 1) == 0) {
ASSERT_EQ(out_contribution.back(), 1.5f); ASSERT_EQ(out_contribution.back(), 1.5f);
} else { } else {
@ -66,10 +68,12 @@ TEST(CpuPredictor, Basic) {
} }
} }
// Test predict contribution (approximate method) // Test predict contribution (approximate method)
cpu_predictor->PredictContribution(dmat.get(), &out_contribution, model, 0, nullptr, true); cpu_predictor->PredictContribution(dmat.get(), &out_contribution_hdv, model,
0, nullptr, true);
for (size_t i = 0; i < out_contribution.size(); ++i) { for (size_t i = 0; i < out_contribution.size(); ++i) {
auto const& contri = out_contribution[i]; auto const& contri = out_contribution[i];
// shift 1 for bias, as test tree is a decision dump, only global bias is filled with LeafValue(). // shift 1 for bias, as test tree is a decision dump, only global bias is
// filled with LeafValue().
if ((i + 1) % (kCols + 1) == 0) { if ((i + 1) % (kCols + 1) == 0) {
ASSERT_EQ(out_contribution.back(), 1.5f); ASSERT_EQ(out_contribution.back(), 1.5f);
} else { } else {
@ -112,8 +116,9 @@ TEST(CpuPredictor, ExternalMemory) {
} }
// Test predict contribution // Test predict contribution
std::vector<float> out_contribution; HostDeviceVector<float> out_contribution_hdv;
cpu_predictor->PredictContribution(dmat.get(), &out_contribution, model); auto& out_contribution = out_contribution_hdv.HostVector();
cpu_predictor->PredictContribution(dmat.get(), &out_contribution_hdv, model);
ASSERT_EQ(out_contribution.size(), dmat->Info().num_row_ * (dmat->Info().num_col_ + 1)); ASSERT_EQ(out_contribution.size(), dmat->Info().num_row_ * (dmat->Info().num_col_ + 1));
for (size_t i = 0; i < out_contribution.size(); ++i) { for (size_t i = 0; i < out_contribution.size(); ++i) {
auto const& contri = out_contribution[i]; auto const& contri = out_contribution[i];
@ -126,8 +131,10 @@ TEST(CpuPredictor, ExternalMemory) {
} }
// Test predict contribution (approximate method) // Test predict contribution (approximate method)
std::vector<float> out_contribution_approximate; HostDeviceVector<float> out_contribution_approximate_hdv;
cpu_predictor->PredictContribution(dmat.get(), &out_contribution_approximate, model, 0, nullptr, true); auto& out_contribution_approximate = out_contribution_approximate_hdv.HostVector();
cpu_predictor->PredictContribution(
dmat.get(), &out_contribution_approximate_hdv, model, 0, nullptr, true);
ASSERT_EQ(out_contribution_approximate.size(), ASSERT_EQ(out_contribution_approximate.size(),
dmat->Info().num_row_ * (dmat->Info().num_col_ + 1)); dmat->Info().num_row_ * (dmat->Info().num_col_ + 1));
for (size_t i = 0; i < out_contribution.size(); ++i) { for (size_t i = 0; i < out_contribution.size(); ++i) {

View File

@ -176,12 +176,13 @@ TEST(GPUPredictor, ShapStump) {
model.CommitModel(std::move(trees), 0); model.CommitModel(std::move(trees), 0);
auto gpu_lparam = CreateEmptyGenericParam(0); auto gpu_lparam = CreateEmptyGenericParam(0);
std::unique_ptr<Predictor> gpu_predictor = std::unique_ptr<Predictor> gpu_predictor = std::unique_ptr<Predictor>(
std::unique_ptr<Predictor>(Predictor::Create("gpu_predictor", &gpu_lparam)); Predictor::Create("gpu_predictor", &gpu_lparam));
gpu_predictor->Configure({}); gpu_predictor->Configure({});
std::vector<float > phis; HostDeviceVector<float> predictions;
auto dmat = RandomDataGenerator(3, 1, 0).GenerateDMatrix(); auto dmat = RandomDataGenerator(3, 1, 0).GenerateDMatrix();
gpu_predictor->PredictContribution(dmat.get(), &phis, model); gpu_predictor->PredictContribution(dmat.get(), &predictions, model);
auto& phis = predictions.HostVector();
EXPECT_EQ(phis[0], 0.0); EXPECT_EQ(phis[0], 0.0);
EXPECT_EQ(phis[1], param.base_score); EXPECT_EQ(phis[1], param.base_score);
EXPECT_EQ(phis[2], 0.0); EXPECT_EQ(phis[2], 0.0);
@ -202,19 +203,20 @@ TEST(GPUPredictor, Shap) {
auto gpu_lparam = CreateEmptyGenericParam(0); auto gpu_lparam = CreateEmptyGenericParam(0);
auto cpu_lparam = CreateEmptyGenericParam(-1); auto cpu_lparam = CreateEmptyGenericParam(-1);
std::unique_ptr<Predictor> gpu_predictor = std::unique_ptr<Predictor> gpu_predictor = std::unique_ptr<Predictor>(
std::unique_ptr<Predictor>(Predictor::Create("gpu_predictor", &gpu_lparam)); Predictor::Create("gpu_predictor", &gpu_lparam));
std::unique_ptr<Predictor> cpu_predictor = std::unique_ptr<Predictor> cpu_predictor = std::unique_ptr<Predictor>(
std::unique_ptr<Predictor>(Predictor::Create("cpu_predictor", &cpu_lparam)); Predictor::Create("cpu_predictor", &cpu_lparam));
gpu_predictor->Configure({}); gpu_predictor->Configure({});
cpu_predictor->Configure({}); cpu_predictor->Configure({});
std::vector<float > phis; HostDeviceVector<float> predictions;
std::vector<float > cpu_phis; HostDeviceVector<float> cpu_predictions;
auto dmat = RandomDataGenerator(3, 1, 0).GenerateDMatrix(); auto dmat = RandomDataGenerator(3, 1, 0).GenerateDMatrix();
gpu_predictor->PredictContribution(dmat.get(), &phis, model); gpu_predictor->PredictContribution(dmat.get(), &predictions, model);
cpu_predictor->PredictContribution(dmat.get(), &cpu_phis, model); cpu_predictor->PredictContribution(dmat.get(), &cpu_predictions, model);
for(auto i = 0ull; i < phis.size(); i++) auto& phis = predictions.HostVector();
{ auto& cpu_phis = cpu_predictions.HostVector();
for (auto i = 0ull; i < phis.size(); i++) {
EXPECT_NEAR(cpu_phis[i], phis[i], 1e-3); EXPECT_NEAR(cpu_phis[i], phis[i], 1e-3);
} }
} }

View File

@ -16,7 +16,7 @@ shap_parameter_strategy = strategies.fixed_dictionaries({
'max_depth': strategies.integers(0, 11), 'max_depth': strategies.integers(0, 11),
'max_leaves': strategies.integers(0, 256), 'max_leaves': strategies.integers(0, 256),
'num_parallel_tree': strategies.sampled_from([1, 10]), 'num_parallel_tree': strategies.sampled_from([1, 10]),
}) }).filter(lambda x: x['max_depth'] > 0 or x['max_leaves'] > 0)
class TestGPUPredict(unittest.TestCase): class TestGPUPredict(unittest.TestCase):
@ -194,26 +194,31 @@ class TestGPUPredict(unittest.TestCase):
for i in range(10): for i in range(10):
run_threaded_predict(X, rows, predict_df) run_threaded_predict(X, rows, predict_df)
@given(strategies.integers(1, 200), @given(strategies.integers(1, 10),
tm.dataset_strategy, shap_parameter_strategy, strategies.booleans()) tm.dataset_strategy, shap_parameter_strategy)
@settings(deadline=None) @settings(deadline=None)
def test_shap(self, num_rounds, dataset, param, all_rows): def test_shap(self, num_rounds, dataset, param):
if param['max_depth'] == 0 and param['max_leaves'] == 0:
return
param.update({"predictor": "gpu_predictor", "gpu_id": 0}) param.update({"predictor": "gpu_predictor", "gpu_id": 0})
param = dataset.set_params(param) param = dataset.set_params(param)
dmat = dataset.get_dmat() dmat = dataset.get_dmat()
bst = xgb.train(param, dmat, num_rounds) bst = xgb.train(param, dmat, num_rounds)
if all_rows:
test_dmat = xgb.DMatrix(dataset.X, dataset.y, dataset.w, dataset.margin) test_dmat = xgb.DMatrix(dataset.X, dataset.y, dataset.w, dataset.margin)
else:
test_dmat = xgb.DMatrix(dataset.X[0:1, :])
shap = bst.predict(test_dmat, pred_contribs=True) shap = bst.predict(test_dmat, pred_contribs=True)
bst.set_param({"predictor": "cpu_predictor"})
cpu_shap = bst.predict(test_dmat, pred_contribs=True)
margin = bst.predict(test_dmat, output_margin=True) margin = bst.predict(test_dmat, output_margin=True)
assert np.allclose(shap, cpu_shap, 1e-3, 1e-3)
# feature contributions should add up to predictions
assume(len(dataset.y) > 0) assume(len(dataset.y) > 0)
assert np.allclose(np.sum(shap, axis=len(shap.shape) - 1), margin, 1e-3, 1e-3) assert np.allclose(np.sum(shap, axis=len(shap.shape) - 1), margin, 1e-3, 1e-3)
@given(strategies.integers(1, 10),
tm.dataset_strategy, shap_parameter_strategy)
@settings(deadline=None, max_examples=20)
def test_shap_interactions(self, num_rounds, dataset, param):
param.update({"predictor": "gpu_predictor", "gpu_id": 0})
param = dataset.set_params(param)
dmat = dataset.get_dmat()
bst = xgb.train(param, dmat, num_rounds)
test_dmat = xgb.DMatrix(dataset.X, dataset.y, dataset.w, dataset.margin)
shap = bst.predict(test_dmat, pred_interactions=True)
margin = bst.predict(test_dmat, output_margin=True)
assume(len(dataset.y) > 0)
assert np.allclose(np.sum(shap, axis=(len(shap.shape) - 1, len(shap.shape) - 2)), margin,
1e-3, 1e-3)

View File

@ -0,0 +1,25 @@
import numpy as np
import xgboost as xgb
import testing as tm
import pytest
try:
import shap
except ImportError:
shap = None
pass
pytestmark = pytest.mark.skipif(shap is None, reason="Requires shap package")
# Check integration is not broken from xgboost side
# Changes in binary format may cause problems
def test_with_shap():
X, y = shap.datasets.boston()
dtrain = xgb.DMatrix(X, label=y)
model = xgb.train({"learning_rate": 0.01}, dtrain, 10)
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X)
margin = model.predict(dtrain, output_margin=True)
assert np.allclose(np.sum(shap_values, axis=len(shap_values.shape) - 1),
margin - explainer.expected_value, 1e-3, 1e-3)