Use DART tree weights when computing SHAPs (#5050)

This PR fixes tree weights in dart being ignored when computing contributions.

* Fix ellpack page source link.
* Add tree weights to compute contribution.
This commit is contained in:
Kodi Arfer 2019-12-03 06:55:53 -05:00 committed by Jiaming Yuan
parent 64f4361b47
commit f2277e7106
12 changed files with 88 additions and 16 deletions

View File

@ -142,6 +142,44 @@ test_that("predict feature contributions works", {
} }
}) })
test_that("SHAPs sum to predictions, with or without DART", {
d <- cbind(
x1 = rnorm(100),
x2 = rnorm(100),
x3 = rnorm(100))
y <- d[,"x1"] + d[,"x2"]^2 +
ifelse(d[,"x3"] > .5, d[,"x3"]^2, 2^d[,"x3"]) +
rnorm(100)
nrounds <- 30
for (booster in list("gbtree", "dart")) {
fit <- xgboost(
params = c(
list(
booster = booster,
objective = "reg:linear",
eval_metric = "rmse"),
if (booster == "dart")
list(rate_drop = .01, one_drop = T)),
data = d,
label = y,
nrounds = nrounds)
pr <- function(...)
predict(fit, newdata = d, ntreelimit = nrounds, ...)
pred <- pr()
shap <- pr(predcontrib = T)
shapi <- pr(predinteraction = T)
tol = 1e-5
expect_equal(rowSums(shap), pred, tol = tol)
expect_equal(apply(shapi, 1, sum), pred, tol = tol)
for (i in 1 : nrow(d))
for (f in list(rowSums, colSums))
expect_equal(f(shapi[i,,]), shap[i,], tol = tol)
}
})
test_that("xgb-attribute functionality", { test_that("xgb-attribute functionality", {
val <- "my attribute value" val <- "my attribute value"
list.val <- list(my_attr=val, a=123, b='ok') list.val <- list(my_attr=val, a=123, b='ok')

View File

@ -29,10 +29,11 @@
// data // data
#include "../src/data/data.cc" #include "../src/data/data.cc"
#include "../src/data/ellpack_page.cc"
#include "../src/data/simple_csr_source.cc" #include "../src/data/simple_csr_source.cc"
#include "../src/data/simple_dmatrix.cc" #include "../src/data/simple_dmatrix.cc"
#include "../src/data/sparse_page_raw_format.cc" #include "../src/data/sparse_page_raw_format.cc"
#include "../src/data/ellpack_page.cc"
#include "../src/data/ellpack_page_source.cc"
// prediction // prediction
#include "../src/predictor/predictor.cc" #include "../src/predictor/predictor.cc"

View File

@ -149,6 +149,7 @@ class Predictor {
* \param [in,out] out_contribs The output feature contribs. * \param [in,out] out_contribs The output feature contribs.
* \param model Model to make predictions from. * \param model Model to make predictions from.
* \param ntree_limit (Optional) The ntree limit. * \param ntree_limit (Optional) The ntree limit.
* \param tree_weights (Optional) Weights to multiply each tree by.
* \param approximate Use fast approximate algorithm. * \param approximate Use fast approximate algorithm.
* \param condition Condition on the condition_feature (0=no, -1=cond off, 1=cond on). * \param condition Condition on the condition_feature (0=no, -1=cond off, 1=cond on).
* \param condition_feature Feature to condition on (i.e. fix) during calculations. * \param condition_feature Feature to condition on (i.e. fix) during calculations.
@ -158,6 +159,7 @@ class Predictor {
std::vector<bst_float>* out_contribs, std::vector<bst_float>* out_contribs,
const gbm::GBTreeModel& model, const gbm::GBTreeModel& model,
unsigned ntree_limit = 0, unsigned ntree_limit = 0,
std::vector<bst_float>* tree_weights = nullptr,
bool approximate = false, bool approximate = false,
int condition = 0, int condition = 0,
unsigned condition_feature = 0) = 0; unsigned condition_feature = 0) = 0;
@ -166,6 +168,7 @@ class Predictor {
std::vector<bst_float>* out_contribs, std::vector<bst_float>* out_contribs,
const gbm::GBTreeModel& model, const gbm::GBTreeModel& model,
unsigned ntree_limit = 0, unsigned ntree_limit = 0,
std::vector<bst_float>* tree_weights = nullptr,
bool approximate = false) = 0; bool approximate = false) = 0;
/** /**

View File

@ -25,6 +25,3 @@ USE_AZURE = 0
# - librabit.a Normal distributed version. # - librabit.a Normal distributed version.
# - librabit_empty.a Non distributed mock version, # - librabit_empty.a Non distributed mock version,
LIB_RABIT = librabit_empty.a LIB_RABIT = librabit_empty.a
DMLC_CFLAGS = -DDMLC_ENABLE_STD_THREAD=0
ADD_CFLAGS = -DDMLC_ENABLE_STD_THREAD=0

View File

@ -10,6 +10,8 @@ namespace xgboost {
class EllpackPageImpl {}; class EllpackPageImpl {};
EllpackPage::EllpackPage() = default;
EllpackPage::EllpackPage(DMatrix* dmat, const BatchParam& param) { EllpackPage::EllpackPage(DMatrix* dmat, const BatchParam& param) {
LOG(FATAL) << "Internal Error: XGBoost is not compiled with CUDA but EllpackPage is required"; LOG(FATAL) << "Internal Error: XGBoost is not compiled with CUDA but EllpackPage is required";
} }

View File

@ -3,8 +3,8 @@
*/ */
#ifndef XGBOOST_USE_CUDA #ifndef XGBOOST_USE_CUDA
#include <xgboost/data.h>
#include "ellpack_page_source.h" #include "ellpack_page_source.h"
namespace xgboost { namespace xgboost {
namespace data { namespace data {

View File

@ -1,14 +1,14 @@
/*! /*!
* Copyright 2019 XGBoost contributors * Copyright 2019 XGBoost contributors
*/ */
#include "ellpack_page_source.h"
#include <memory> #include <memory>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "../common/hist_util.h" #include "../common/hist_util.h"
#include "ellpack_page_source.h"
#include "sparse_page_source.h"
#include "ellpack_page.cuh" #include "ellpack_page.cuh"
namespace xgboost { namespace xgboost {

View File

@ -9,7 +9,6 @@
#include <memory> #include <memory>
#include <string> #include <string>
#include "sparse_page_source.h"
#include "../common/timer.h" #include "../common/timer.h"
namespace xgboost { namespace xgboost {

View File

@ -348,6 +348,23 @@ class Dart : public GBTree {
return GBTree::UseGPU(); return GBTree::UseGPU();
} }
void PredictContribution(DMatrix* p_fmat,
std::vector<bst_float>* out_contribs,
unsigned ntree_limit, bool approximate, int condition,
unsigned condition_feature) override {
CHECK(configured_);
cpu_predictor_->PredictContribution(p_fmat, out_contribs, model_,
ntree_limit, &weight_drop_, approximate);
}
void PredictInteractionContributions(DMatrix* p_fmat,
std::vector<bst_float>* out_contribs,
unsigned ntree_limit, bool approximate) override {
CHECK(configured_);
cpu_predictor_->PredictInteractionContributions(p_fmat, out_contribs, model_,
ntree_limit, &weight_drop_, approximate);
}
protected: protected:
friend class GBTree; friend class GBTree;
// internal prediction loop // internal prediction loop

View File

@ -216,7 +216,8 @@ class GBTree : public GradientBooster {
unsigned ntree_limit, bool approximate, int condition, unsigned ntree_limit, bool approximate, int condition,
unsigned condition_feature) override { unsigned condition_feature) override {
CHECK(configured_); CHECK(configured_);
cpu_predictor_->PredictContribution(p_fmat, out_contribs, model_, ntree_limit, approximate); cpu_predictor_->PredictContribution(p_fmat, out_contribs, model_,
ntree_limit, nullptr, approximate);
} }
void PredictInteractionContributions(DMatrix* p_fmat, void PredictInteractionContributions(DMatrix* p_fmat,
@ -224,7 +225,7 @@ class GBTree : public GradientBooster {
unsigned ntree_limit, bool approximate) override { unsigned ntree_limit, bool approximate) override {
CHECK(configured_); CHECK(configured_);
cpu_predictor_->PredictInteractionContributions(p_fmat, out_contribs, model_, cpu_predictor_->PredictInteractionContributions(p_fmat, out_contribs, model_,
ntree_limit, approximate); ntree_limit, nullptr, approximate);
} }
std::vector<std::string> DumpModel(const FeatureMap& fmap, std::vector<std::string> DumpModel(const FeatureMap& fmap,

View File

@ -257,6 +257,7 @@ class CPUPredictor : public Predictor {
void PredictContribution(DMatrix* p_fmat, std::vector<bst_float>* out_contribs, void PredictContribution(DMatrix* p_fmat, std::vector<bst_float>* out_contribs,
const gbm::GBTreeModel& model, unsigned ntree_limit, const gbm::GBTreeModel& model, unsigned ntree_limit,
std::vector<bst_float>* tree_weights,
bool approximate, bool approximate,
int condition, int condition,
unsigned condition_feature) override { unsigned condition_feature) override {
@ -296,16 +297,23 @@ class CPUPredictor : public Predictor {
bst_float* p_contribs = bst_float* p_contribs =
&contribs[(row_idx * ngroup + gid) * ncolumns]; &contribs[(row_idx * ngroup + gid) * ncolumns];
feats.Fill(batch[i]); feats.Fill(batch[i]);
std::vector<bst_float> this_tree_contribs;
this_tree_contribs.resize(ncolumns);
// calculate contributions // calculate contributions
for (unsigned j = 0; j < ntree_limit; ++j) { for (unsigned j = 0; j < ntree_limit; ++j) {
std::fill(this_tree_contribs.begin(), this_tree_contribs.end(), 0);
if (model.tree_info[j] != gid) { if (model.tree_info[j] != gid) {
continue; continue;
} }
if (!approximate) { if (!approximate) {
model.trees[j]->CalculateContributions(feats, root_id, p_contribs, model.trees[j]->CalculateContributions(feats, root_id, &this_tree_contribs[0],
condition, condition_feature); condition, condition_feature);
} else { } else {
model.trees[j]->CalculateContributionsApprox(feats, root_id, p_contribs); model.trees[j]->CalculateContributionsApprox(feats, root_id, &this_tree_contribs[0]);
}
for (int ci = 0 ; ci < ncolumns ; ++ci) {
p_contribs[ci] += this_tree_contribs[ci] *
(tree_weights == nullptr ? 1 : (*tree_weights)[j]);
} }
} }
feats.Drop(batch[i]); feats.Drop(batch[i]);
@ -322,6 +330,7 @@ class CPUPredictor : public Predictor {
void PredictInteractionContributions(DMatrix* p_fmat, std::vector<bst_float>* out_contribs, void PredictInteractionContributions(DMatrix* p_fmat, std::vector<bst_float>* out_contribs,
const gbm::GBTreeModel& model, unsigned ntree_limit, const gbm::GBTreeModel& model, unsigned ntree_limit,
std::vector<bst_float>* tree_weights,
bool approximate) override { bool approximate) override {
const MetaInfo& info = p_fmat->Info(); const MetaInfo& info = p_fmat->Info();
const int ngroup = model.param.num_output_group; const int ngroup = model.param.num_output_group;
@ -340,10 +349,13 @@ class CPUPredictor : public Predictor {
// Compute the difference in effects when conditioning on each of the features on and off // Compute the difference in effects when conditioning on each of the features on and off
// see: Axiomatic characterizations of probabilistic and // see: Axiomatic characterizations of probabilistic and
// cardinal-probabilistic interaction indices // cardinal-probabilistic interaction indices
PredictContribution(p_fmat, &contribs_diag, model, ntree_limit, approximate, 0, 0); PredictContribution(p_fmat, &contribs_diag, model, ntree_limit,
tree_weights, approximate, 0, 0);
for (size_t i = 0; i < ncolumns + 1; ++i) { for (size_t i = 0; i < ncolumns + 1; ++i) {
PredictContribution(p_fmat, &contribs_off, model, ntree_limit, approximate, -1, i); PredictContribution(p_fmat, &contribs_off, model, ntree_limit,
PredictContribution(p_fmat, &contribs_on, model, ntree_limit, approximate, 1, i); tree_weights, approximate, -1, i);
PredictContribution(p_fmat, &contribs_on, model, ntree_limit,
tree_weights, approximate, 1, i);
for (size_t j = 0; j < info.num_row_; ++j) { for (size_t j = 0; j < info.num_row_; ++j) {
for (int l = 0; l < ngroup; ++l) { for (int l = 0; l < ngroup; ++l) {

View File

@ -397,6 +397,7 @@ class GPUPredictor : public xgboost::Predictor {
void PredictContribution(DMatrix* p_fmat, void PredictContribution(DMatrix* p_fmat,
std::vector<bst_float>* out_contribs, std::vector<bst_float>* out_contribs,
const gbm::GBTreeModel& model, unsigned ntree_limit, const gbm::GBTreeModel& model, unsigned ntree_limit,
std::vector<bst_float>* tree_weights,
bool approximate, int condition, bool approximate, int condition,
unsigned condition_feature) override { unsigned condition_feature) override {
LOG(FATAL) << "Internal error: " << __func__ LOG(FATAL) << "Internal error: " << __func__
@ -407,6 +408,7 @@ class GPUPredictor : public xgboost::Predictor {
std::vector<bst_float>* out_contribs, std::vector<bst_float>* out_contribs,
const gbm::GBTreeModel& model, const gbm::GBTreeModel& model,
unsigned ntree_limit, unsigned ntree_limit,
std::vector<bst_float>* tree_weights,
bool approximate) override { bool approximate) override {
LOG(FATAL) << "Internal error: " << __func__ LOG(FATAL) << "Internal error: " << __func__
<< " is not implemented in GPU Predictor."; << " is not implemented in GPU Predictor.";