Add SHAP interaction effects, fix minor bug, and add cox loss (#3043)

* Add interaction effects and cox loss

* Minimize whitespace changes

* Cox loss now no longer needs a pre-sorted dataset.

* Address code review comments

* Remove mem check, rename to pred_interactions, include bias

* Make lint happy

* More lint fixes

* Fix cox loss indexing

* Fix main effects and tests

* Fix lint

* Use half interaction values on the off-diagonals

* Fix lint again
Scott Lundberg 2018-02-07 18:38:01 -08:00 committed by Vadim Khotilovich
parent 077abb35cd
commit d878c36c84
19 changed files with 638 additions and 125 deletions


@@ -65,8 +65,8 @@ Parameters for Tree Booster
- 'exact': Exact greedy algorithm.
- 'approx': Approximate greedy algorithm using sketching and histogram.
- 'hist': Fast histogram optimized approximate greedy algorithm. It uses some performance improvements such as bins caching.
- 'gpu_exact': GPU implementation of exact algorithm.
- 'gpu_hist': GPU implementation of hist algorithm.
* sketch_eps, [default=0.03]
- This is only used for approximate greedy algorithm.
- This roughly translated into ```O(1 / sketch_eps)``` number of bins.
@@ -170,6 +170,8 @@ Specify the learning task and the corresponding learning objective. The objectiv
they can only be used when the entire training session uses the same dataset
- "count:poisson" --poisson regression for count data, output mean of poisson distribution
- max_delta_step is set to 0.7 by default in poisson regression (used to safeguard optimization)
- "survival:cox" --Cox regression for right censored survival time data (negative values are considered right censored).
Note that predictions are returned on the hazard ratio scale (i.e., as HR = exp(marginal_prediction) in the proportional hazard function h(t) = h0(t) * HR).
- "multi:softmax" --set XGBoost to do multiclass classification using the softmax objective, you also need to set num_class(number of classes) - "multi:softmax" --set XGBoost to do multiclass classification using the softmax objective, you also need to set num_class(number of classes)
- "multi:softprob" --same as softmax, but output a vector of ndata * nclass, which can be further reshaped to ndata, nclass matrix. The result contains predicted probability of each data point belonging to each class. - "multi:softprob" --same as softmax, but output a vector of ndata * nclass, which can be further reshaped to ndata, nclass matrix. The result contains predicted probability of each data point belonging to each class.
- "rank:pairwise" --set XGBoost to do ranking task by minimizing the pairwise loss - "rank:pairwise" --set XGBoost to do ranking task by minimizing the pairwise loss
@ -197,6 +199,7 @@ Specify the learning task and the corresponding learning objective. The objectiv
training repeatedly training repeatedly
- "poisson-nloglik": negative log-likelihood for Poisson regression - "poisson-nloglik": negative log-likelihood for Poisson regression
- "gamma-nloglik": negative log-likelihood for gamma regression - "gamma-nloglik": negative log-likelihood for gamma regression
- "cox-nloglik": negative partial log-likelihood for Cox proportional hazards regression
- "gamma-deviance": residual deviance for gamma regression - "gamma-deviance": residual deviance for gamma regression
- "tweedie-nloglik": negative log-likelihood for Tweedie regression (at a specified value of the tweedie_variance_power parameter) - "tweedie-nloglik": negative log-likelihood for Tweedie regression (at a specified value of the tweedie_variance_power parameter)
* seed [default=0] * seed [default=0]
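A minimal usage sketch of the objective and metric documented above, on hypothetical synthetic survival data (negative labels mark right-censored rows; all data and parameter values here are illustrative):

```python
import numpy as np
import xgboost as xgb

# Hypothetical survival data: observed times plus an event indicator
# (1 = event observed, 0 = right censored).
rng = np.random.RandomState(0)
X = rng.randn(200, 5)
time = rng.exponential(scale=10.0, size=200)
event = rng.binomial(1, 0.7, size=200)

# survival:cox expects right-censored observations to carry negative labels.
label = np.where(event == 1, time, -time)

dtrain = xgb.DMatrix(X, label=label)
params = {'objective': 'survival:cox', 'eval_metric': 'cox-nloglik', 'eta': 0.1}
bst = xgb.train(params, dtrain, num_boost_round=50, evals=[(dtrain, 'train')])

# Predictions are hazard ratios, HR = exp(margin), as noted above.
hazard_ratio = bst.predict(dtrain)
```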


@@ -12,6 +12,7 @@
#include <string>
#include <memory>
#include <vector>
#include <numeric>
#include "./base.h" #include "./base.h"
namespace xgboost { namespace xgboost {
@ -76,6 +77,19 @@ struct MetaInfo {
inline unsigned GetRoot(size_t i) const { inline unsigned GetRoot(size_t i) const {
return root_index.size() != 0 ? root_index[i] : 0U; return root_index.size() != 0 ? root_index[i] : 0U;
} }
/*! \brief get sorted indexes (argsort) of labels by absolute value (used by cox loss) */
inline const std::vector<size_t>& LabelAbsSort() const {
if (label_order_cache.size() == labels.size()) {
return label_order_cache;
}
label_order_cache.resize(labels.size());
std::iota(label_order_cache.begin(), label_order_cache.end(), 0);
const auto l = labels;
XGBOOST_PARALLEL_SORT(label_order_cache.begin(), label_order_cache.end(),
[&l](size_t i1, size_t i2) {return std::abs(l[i1]) < std::abs(l[i2]);});
return label_order_cache;
}
/*! \brief clear all the information */
void Clear();
/*!
@@ -96,6 +110,10 @@ struct MetaInfo {
* \param num Number of elements in the source array.
*/
void SetInfo(const char* key, const void* dptr, DataType dtype, size_t num);
private:
/*! \brief argsort of labels */
mutable std::vector<size_t> label_order_cache;
};
/*! \brief read-only sparse instance batch in CSR format */
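For reference, the cached ordering built by LabelAbsSort is simply an argsort of the labels by absolute value; a rough numpy equivalent (with `labels` a hypothetical array of signed survival times):

```python
import numpy as np

label_order = np.argsort(np.abs(labels))  # same ordering the cox loss consumes
```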


@@ -124,10 +124,17 @@ class GradientBooster {
* \param ntree_limit limit the number of trees used in prediction, when it equals 0, this means
* we do not limit number of trees
* \param approximate use a faster (inconsistent) approximation of SHAP values
* \param condition condition on the condition_feature (0=no, -1=cond off, 1=cond on).
* \param condition_feature feature to condition on (i.e. fix) during calculations
*/
virtual void PredictContribution(DMatrix* dmat,
std::vector<bst_float>* out_contribs,
unsigned ntree_limit = 0, bool approximate = false,
int condition = 0, unsigned condition_feature = 0) = 0;
virtual void PredictInteractionContributions(DMatrix* dmat,
std::vector<bst_float>* out_contribs,
unsigned ntree_limit, bool approximate) = 0;
/*!
* \brief dump the model in the requested format


@@ -105,6 +105,7 @@ class Learner : public rabit::Serializable {
* \param pred_leaf whether to only predict the leaf index of each tree in a boosted tree predictor
* \param pred_contribs whether to only predict the feature contributions
* \param approx_contribs whether to approximate the feature contributions for speed
* \param pred_interactions whether to compute the feature pair contributions
*/
virtual void Predict(DMatrix* data,
bool output_margin,
@@ -112,7 +113,9 @@ class Learner : public rabit::Serializable {
unsigned ntree_limit = 0,
bool pred_leaf = false,
bool pred_contribs = false,
bool approx_contribs = false,
bool pred_interactions = false) const = 0;
/*!
* \brief Set additional attribute to the Booster.
* The property will be saved along the booster.


@@ -153,14 +153,24 @@ class Predictor {
* a vector of length (nfeats + 1) * num_output_group * nsample, arranged in
* that order.
*
* \param [in,out] dmat The input feature matrix.
* \param [in,out] out_contribs The output feature contribs.
* \param model Model to make predictions from.
* \param ntree_limit (Optional) The ntree limit.
* \param approximate Use fast approximate algorithm.
* \param condition Condition on the condition_feature (0=no, -1=cond off, 1=cond on).
* \param condition_feature Feature to condition on (i.e. fix) during calculations.
*/
virtual void PredictContribution(DMatrix* dmat,
std::vector<bst_float>* out_contribs,
const gbm::GBTreeModel& model,
unsigned ntree_limit = 0,
bool approximate = false,
int condition = 0,
unsigned condition_feature = 0) = 0;
virtual void PredictInteractionContributions(DMatrix* dmat,
std::vector<bst_float>* out_contribs,
const gbm::GBTreeModel& model,
unsigned ntree_limit = 0,


@@ -501,13 +501,33 @@ class RegTree: public TreeModel<bst_float, RTreeNodeStat> {
* \param feat dense feature vector, if the feature is missing the field is set to NaN
* \param root_id starting root index of the instance
* \param out_contribs output vector to hold the contributions
* \param condition fix one feature to either off (-1) on (1) or not fixed (0 default)
* \param condition_feature the index of the feature to fix
*/
inline void CalculateContributions(const RegTree::FVec& feat, unsigned root_id,
bst_float *out_contribs,
int condition = 0,
unsigned condition_feature = 0) const;
/*!
* \brief Recursive function that computes the feature attributions for a single tree.
* \param feat dense feature vector, if the feature is missing the field is set to NaN
* \param phi dense output vector of feature attributions
* \param node_index the index of the current node in the tree
* \param unique_depth how many unique features are above the current node in the tree
* \param parent_unique_path a vector of statistics about our current path through the tree
* \param parent_zero_fraction what fraction of the parent path weight is coming as 0 (integrated)
* \param parent_one_fraction what fraction of the parent path weight is coming as 1 (fixed)
* \param parent_feature_index what feature the parent node used to split
* \param condition fix one feature to either off (-1) on (1) or not fixed (0 default)
* \param condition_feature the index of the feature to fix
* \param condition_fraction what fraction of the current weight matches our conditioning feature
*/
inline void TreeShap(const RegTree::FVec& feat, bst_float *phi,
unsigned node_index, unsigned unique_depth,
PathElement *parent_unique_path, bst_float parent_zero_fraction,
bst_float parent_one_fraction, int parent_feature_index,
int condition, unsigned condition_feature,
bst_float condition_fraction) const;
/*!
* \brief calculate the approximate feature contributions for the given root
@@ -700,7 +720,7 @@ inline bst_float UnwoundPathSum(const PathElement *unique_path, unsigned unique_
/ static_cast<bst_float>((i + 1) * one_fraction);
total += tmp;
next_one_portion = unique_path[i].pweight - tmp * zero_fraction * ((unique_depth - i)
/ static_cast<bst_float>(unique_depth + 1));
} else {
total += (unique_path[i].pweight / zero_fraction) / ((unique_depth - i)
/ static_cast<bst_float>(unique_depth + 1));
@@ -713,15 +733,22 @@ inline bst_float UnwoundPathSum(const PathElement *unique_path, unsigned unique_
inline void RegTree::TreeShap(const RegTree::FVec& feat, bst_float *phi,
unsigned node_index, unsigned unique_depth,
PathElement *parent_unique_path, bst_float parent_zero_fraction,
bst_float parent_one_fraction, int parent_feature_index,
int condition, unsigned condition_feature,
bst_float condition_fraction) const {
const auto node = (*this)[node_index];
// stop if we have no weight coming down to us
if (condition_fraction == 0) return;
// extend the unique path
PathElement *unique_path = parent_unique_path + unique_depth + 1;
std::copy(parent_unique_path, parent_unique_path + unique_depth + 1, unique_path);
if (condition == 0 || condition_feature != static_cast<unsigned>(parent_feature_index)) {
ExtendPath(unique_path, unique_depth, parent_zero_fraction,
parent_one_fraction, parent_feature_index);
}
const unsigned split_index = node.split_index();
// leaf node
@@ -729,7 +756,8 @@ inline void RegTree::TreeShap(const RegTree::FVec& feat, bst_float *phi,
for (unsigned i = 1; i <= unique_depth; ++i) {
const bst_float w = UnwoundPathSum(unique_path, unique_depth, i);
const PathElement &el = unique_path[i];
phi[el.feature_index] += w * (el.one_fraction - el.zero_fraction)
* node.leaf_value() * condition_fraction;
}
// internal node
@@ -764,34 +792,44 @@ inline void RegTree::TreeShap(const RegTree::FVec& feat, bst_float *phi,
unique_depth -= 1;
}
// divide up the condition_fraction among the recursive calls
bst_float hot_condition_fraction = condition_fraction;
bst_float cold_condition_fraction = condition_fraction;
if (condition > 0 && split_index == condition_feature) {
cold_condition_fraction = 0;
unique_depth -= 1;
} else if (condition < 0 && split_index == condition_feature) {
hot_condition_fraction *= hot_zero_fraction;
cold_condition_fraction *= cold_zero_fraction;
unique_depth -= 1;
}
TreeShap(feat, phi, hot_index, unique_depth + 1, unique_path,
hot_zero_fraction * incoming_zero_fraction, incoming_one_fraction,
split_index, condition, condition_feature, hot_condition_fraction);
TreeShap(feat, phi, cold_index, unique_depth + 1, unique_path,
cold_zero_fraction * incoming_zero_fraction, 0,
split_index, condition, condition_feature, cold_condition_fraction);
}
}
inline void RegTree::CalculateContributions(const RegTree::FVec& feat, unsigned root_id,
bst_float *out_contribs,
int condition,
unsigned condition_feature) const {
// find the expected value of the tree's predictions
if (condition == 0) {
bst_float node_value = this->node_mean_values[static_cast<int>(root_id)];
out_contribs[feat.size()] += node_value;
}
// Preallocate space for the unique path data
const int maxd = this->MaxDepth(root_id) + 2;
PathElement *unique_path_data = new PathElement[(maxd * (maxd + 1)) / 2];
TreeShap(feat, out_contribs, root_id, 0, unique_path_data,
1, 1, -1, condition, condition_feature, 1);
delete[] unique_path_data;
}


@@ -992,7 +992,7 @@ class Booster(object):
return self.eval_set([(data, name)], iteration)
def predict(self, data, output_margin=False, ntree_limit=0, pred_leaf=False,
pred_contribs=False, approx_contribs=False, pred_interactions=False):
"""
Predict with data.
@@ -1019,14 +1019,21 @@ class Booster(object):
in both tree 1 and tree 0.
pred_contribs : bool
When this is True the output will be a matrix of size (nsample, nfeats + 1)
with each record indicating the feature contributions (SHAP values) for that
prediction. The sum of all feature contributions is equal to the raw untransformed
margin value of the prediction. Note the final column is the bias term.
approx_contribs : bool
Approximate the contributions of each feature
pred_interactions : bool
When this is True the output will be a matrix of size (nsample, nfeats + 1, nfeats + 1)
indicating the SHAP interaction values for each pair of features. The sum of each
row (or column) of the interaction values equals the corresponding SHAP value (from
pred_contribs), and the sum of the entire matrix equals the raw untransformed margin
value of the prediction. Note the last row and column correspond to the bias term.
Returns
-------
prediction : numpy array
@@ -1040,6 +1047,8 @@ class Booster(object):
option_mask |= 0x04
if approx_contribs:
option_mask |= 0x08
if pred_interactions:
option_mask |= 0x10
self._validate_features(data)
@@ -1055,8 +1064,22 @@ class Booster(object):
preds = preds.astype(np.int32)
nrow = data.num_row()
if preds.size != nrow and preds.size % nrow == 0:
chunk_size = int(preds.size / nrow)
if pred_interactions:
ngroup = int(chunk_size / ((data.num_col() + 1) * (data.num_col() + 1)))
if ngroup == 1:
preds = preds.reshape(nrow, data.num_col() + 1, data.num_col() + 1)
else:
preds = preds.reshape(nrow, ngroup, data.num_col() + 1, data.num_col() + 1)
elif pred_contribs:
ngroup = int(chunk_size / (data.num_col() + 1))
if ngroup == 1:
preds = preds.reshape(nrow, data.num_col() + 1)
else:
preds = preds.reshape(nrow, ngroup, data.num_col() + 1)
else:
preds = preds.reshape(nrow, chunk_size)
return preds
def save_model(self, fname):
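A hedged usage sketch of the shapes and consistency properties described in the docstring above (`bst` and `dtest` are assumed to be an already-trained single-output booster and a DMatrix):

```python
import numpy as np

margin = bst.predict(dtest, output_margin=True)             # raw untransformed margins
contribs = bst.predict(dtest, pred_contribs=True)           # shape (nrow, nfeats + 1)
interactions = bst.predict(dtest, pred_interactions=True)   # shape (nrow, nfeats + 1, nfeats + 1)

# each row of SHAP values sums to the raw margin
np.testing.assert_array_almost_equal(contribs.sum(axis=1), margin, decimal=5)
# summing the interaction matrix over one axis recovers the SHAP values
np.testing.assert_array_almost_equal(interactions.sum(axis=2), contribs, decimal=5)
# and the whole matrix sums to the raw margin
np.testing.assert_array_almost_equal(interactions.sum(axis=(1, 2)), margin, decimal=5)
```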


@@ -759,7 +759,8 @@ XGB_DLL int XGBoosterPredict(BoosterHandle handle,
&preds, ntree_limit,
(option_mask & 2) != 0,
(option_mask & 4) != 0,
(option_mask & 8) != 0,
(option_mask & 16) != 0);
*out_result = dmlc::BeginPtr(preds);
*len = static_cast<xgboost::bst_ulong>(preds.size());
API_END();


@@ -224,7 +224,8 @@ class GBLinear : public GradientBooster {
void PredictContribution(DMatrix* p_fmat,
std::vector<bst_float>* out_contribs,
unsigned ntree_limit, bool approximate, int condition = 0,
unsigned condition_feature = 0) override {
if (model.weight.size() == 0) {
model.InitModel();
}
@@ -265,6 +266,17 @@
}
}
void PredictInteractionContributions(DMatrix* p_fmat,
std::vector<bst_float>* out_contribs,
unsigned ntree_limit, bool approximate) override {
std::vector<bst_float>& contribs = *out_contribs;
// linear models have no interaction effects
const size_t nelements = model.param.num_feature*model.param.num_feature;
contribs.resize(p_fmat->info().num_row * nelements * model.param.num_output_group);
std::fill(contribs.begin(), contribs.end(), 0);
}
std::vector<std::string> DumpModel(const FeatureMap& fmap,
bool with_stats,
std::string format) const override {


@@ -220,10 +220,18 @@ class GBTree : public GradientBooster {
void PredictContribution(DMatrix* p_fmat,
std::vector<bst_float>* out_contribs,
unsigned ntree_limit, bool approximate, int condition,
unsigned condition_feature) override {
predictor->PredictContribution(p_fmat, out_contribs, model_, ntree_limit, approximate);
}
void PredictInteractionContributions(DMatrix* p_fmat,
std::vector<bst_float>* out_contribs,
unsigned ntree_limit, bool approximate) override {
predictor->PredictInteractionContributions(p_fmat, out_contribs, model_,
ntree_limit, approximate);
}
std::vector<std::string> DumpModel(const FeatureMap& fmap,
bool with_stats,
std::string format) const override {


@@ -443,9 +443,12 @@ class LearnerImpl : public Learner {
void Predict(DMatrix* data, bool output_margin,
std::vector<bst_float>* out_preds, unsigned ntree_limit,
bool pred_leaf, bool pred_contribs, bool approx_contribs,
bool pred_interactions) const override {
if (pred_contribs) {
gbm_->PredictContribution(data, out_preds, ntree_limit, approx_contribs);
} else if (pred_interactions) {
gbm_->PredictInteractionContributions(data, out_preds, ntree_limit, approx_contribs);
} else if (pred_leaf) {
gbm_->PredictLeaf(data, out_preds, ntree_limit);
} else {


@@ -304,6 +304,52 @@ struct EvalMAP : public EvalRankList {
}
};
/*! \brief Cox: Partial likelihood of the Cox proportional hazards model */
struct EvalCox : public Metric {
public:
EvalCox() {}
bst_float Eval(const std::vector<bst_float> &preds,
const MetaInfo &info,
bool distributed) const override {
CHECK(!distributed) << "Cox metric does not support distributed evaluation";
using namespace std; // NOLINT(*)
const bst_omp_uint ndata = static_cast<bst_omp_uint>(info.labels.size());
const std::vector<size_t> &label_order = info.LabelAbsSort();
// pre-compute a sum for the denominator
double exp_p_sum = 0; // we use double because we might need the precision with large datasets
for (omp_ulong i = 0; i < ndata; ++i) {
exp_p_sum += preds[i];
}
double out = 0;
double accumulated_sum = 0;
bst_omp_uint num_events = 0;
for (bst_omp_uint i = 0; i < ndata; ++i) {
const size_t ind = label_order[i];
const auto label = info.labels[ind];
if (label > 0) {
out -= log(preds[ind]) - log(exp_p_sum);
++num_events;
}
// only update the denominator after we move forward in time (labels are sorted)
accumulated_sum += preds[ind];
if (i == ndata - 1 || std::abs(label) < std::abs(info.labels[label_order[i + 1]])) {
exp_p_sum -= accumulated_sum;
accumulated_sum = 0;
}
}
return out/num_events; // normalize by the number of events
}
const char* Name() const override {
return "cox-nloglik";
}
};
XGBOOST_REGISTER_METRIC(AMS, "ams")
.describe("AMS metric for higgs.")
.set_body([](const char* param) { return new EvalAMS(param); });
@@ -323,5 +369,9 @@ XGBOOST_REGISTER_METRIC(NDCG, "ndcg")
XGBOOST_REGISTER_METRIC(MAP, "map")
.describe("map@k for rank.")
.set_body([](const char* param) { return new EvalMAP(param); });
XGBOOST_REGISTER_METRIC(Cox, "cox-nloglik")
.describe("Negative log partial likelihood of Cox proportioanl hazards model.")
.set_body([](const char* param) { return new EvalCox(); });
} // namespace metric
} // namespace xgboost
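A numpy restatement may help when reading the EvalCox loop above; this is a sketch, assuming predictions have already been transformed to hazard ratios (as the survival:cox EvalTransform does) and labels are signed times:

```python
import numpy as np

def cox_nloglik(hazard_ratio, label):
    """Mean negative partial log-likelihood with Breslow-style handling of ties.
    hazard_ratio = exp(margin); negative labels mark right-censored rows."""
    order = np.argsort(np.abs(label))
    hr, y = np.asarray(hazard_ratio)[order], np.asarray(label)[order]
    risk_sum = hr.sum()            # sum of exp(margin) over the current risk set
    out, accumulated, num_events = 0.0, 0.0, 0
    for i in range(len(y)):
        if y[i] > 0:               # an observed event
            out -= np.log(hr[i]) - np.log(risk_sum)
            num_events += 1
        accumulated += hr[i]
        # shrink the risk set only once we move strictly forward in time
        if i == len(y) - 1 or np.abs(y[i]) < np.abs(y[i + 1]):
            risk_sum -= accumulated
            accumulated = 0.0
    return out / num_events        # normalized by the number of events
```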


@@ -197,6 +197,90 @@ XGBOOST_REGISTER_OBJECTIVE(PoissonRegression, "count:poisson")
.describe("Possion regression for count data.")
.set_body([]() { return new PoissonRegression(); });
// cox regression for survival data (negative values mean they are censored)
class CoxRegression : public ObjFunction {
public:
// declare functions
void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {}
void GetGradient(const std::vector<bst_float> &preds,
const MetaInfo &info,
int iter,
std::vector<bst_gpair> *out_gpair) override {
CHECK_NE(info.labels.size(), 0U) << "label set cannot be empty";
CHECK_EQ(preds.size(), info.labels.size()) << "labels are not correctly provided";
out_gpair->resize(preds.size());
const std::vector<size_t> &label_order = info.LabelAbsSort();
const omp_ulong ndata = static_cast<omp_ulong>(preds.size()); // NOLINT(*)
// pre-compute a sum
double exp_p_sum = 0; // we use double because we might need the precision with large datasets
for (omp_ulong i = 0; i < ndata; ++i) {
exp_p_sum += std::exp(preds[label_order[i]]);
}
// start calculating grad and hess
double r_k = 0;
double s_k = 0;
double last_exp_p = 0.0;
double last_abs_y = 0.0;
double accumulated_sum = 0;
for (omp_ulong i = 0; i < ndata; ++i) { // NOLINT(*)
const size_t ind = label_order[i];
const double p = preds[ind];
const double exp_p = std::exp(p);
const double w = info.GetWeight(ind);
const double y = info.labels[ind];
const double abs_y = std::abs(y);
// only update the denominator after we move forward in time (labels are sorted)
// this is Breslow's method for ties
accumulated_sum += last_exp_p;
if (last_abs_y < abs_y) {
exp_p_sum -= accumulated_sum;
accumulated_sum = 0;
} else {
CHECK(last_abs_y <= abs_y) << "CoxRegression: labels must be in sorted order, " <<
"MetaInfo::LabelArgsort failed!";
}
if (y > 0) {
r_k += 1.0/exp_p_sum;
s_k += 1.0/(exp_p_sum*exp_p_sum);
}
const double grad = exp_p*r_k - static_cast<bst_float>(y > 0);
const double hess = exp_p*r_k - exp_p*exp_p * s_k;
out_gpair->at(ind) = bst_gpair(grad * w, hess * w);
last_abs_y = abs_y;
last_exp_p = exp_p;
}
}
void PredTransform(std::vector<bst_float> *io_preds) override {
std::vector<bst_float> &preds = *io_preds;
const long ndata = static_cast<long>(preds.size()); // NOLINT(*)
#pragma omp parallel for schedule(static)
for (long j = 0; j < ndata; ++j) { // NOLINT(*)
preds[j] = std::exp(preds[j]);
}
}
void EvalTransform(std::vector<bst_float> *io_preds) override {
PredTransform(io_preds);
}
bst_float ProbToMargin(bst_float base_score) const override {
return std::log(base_score);
}
const char* DefaultEvalMetric(void) const override {
return "cox-nloglik";
}
};
// register the objective function
XGBOOST_REGISTER_OBJECTIVE(CoxRegression, "survival:cox")
.describe("Cox regression for censored survival data (negative labels are considered censored).")
.set_body([]() { return new CoxRegression(); });
// gamma regression
class GammaRegression : public ObjFunction {
public:
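The gradient and hessian computed in CoxRegression::GetGradient above follow from the Breslow negative partial log-likelihood. A direct O(n^2) numpy sketch of the same quantities (it agrees with the streaming computation when event times are untied; names are illustrative):

```python
import numpy as np

def cox_grad_hess(margin, label):
    """grad_i = exp(p_i) * sum_{events k: t_k <= t_i} 1 / R_k  -  1{i is an event}
    hess_i = exp(p_i) * r_i - exp(p_i)^2 * s_i, where s_i uses 1 / R_k^2 and
    R_k = sum of exp(p_j) over the risk set {j : t_j >= t_k}."""
    p = np.asarray(margin, dtype=float)
    y = np.asarray(label, dtype=float)
    t, event, exp_p = np.abs(y), y > 0, np.exp(p)
    grad, hess = np.zeros_like(p), np.zeros_like(p)
    for i in range(len(p)):
        ks = np.where(event & (t <= t[i]))[0]             # events at or before t_i
        R = np.array([exp_p[t >= t[k]].sum() for k in ks])  # risk-set sums
        r, s = (1.0 / R).sum(), (1.0 / R ** 2).sum()
        grad[i] = exp_p[i] * r - float(event[i])
        hess[i] = exp_p[i] * r - exp_p[i] ** 2 * s
    return grad, hess
```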


@@ -215,7 +215,9 @@ class CPUPredictor : public Predictor {
void PredictContribution(DMatrix* p_fmat, std::vector<bst_float>* out_contribs,
const gbm::GBTreeModel& model, unsigned ntree_limit,
bool approximate,
int condition,
unsigned condition_feature) override {
const int nthread = omp_get_max_threads();
InitThreadTemp(nthread, model.param.num_feature);
const MetaInfo& info = p_fmat->info();
@@ -232,12 +234,10 @@ class CPUPredictor : public Predictor {
// make sure contributions is zeroed, we could be reusing a previously
// allocated one
std::fill(contribs.begin(), contribs.end(), 0);
// initialize tree node mean values
#pragma omp parallel for schedule(static)
for (bst_omp_uint i = 0; i < ntree_limit; ++i) {
model.trees[i]->FillNodeMeanValues();
}
// start collecting the contributions
dmlc::DataIter<RowBatch>* iter = p_fmat->RowIterator();
@@ -263,7 +263,8 @@ class CPUPredictor : public Predictor {
continue;
}
if (!approximate) {
model.trees[j]->CalculateContributions(feats, root_id, p_contribs,
condition, condition_feature);
} else {
model.trees[j]->CalculateContributionsApprox(feats, root_id, p_contribs);
}
@@ -279,6 +280,50 @@ class CPUPredictor : public Predictor {
}
}
}
void PredictInteractionContributions(DMatrix* p_fmat, std::vector<bst_float>* out_contribs,
const gbm::GBTreeModel& model, unsigned ntree_limit,
bool approximate) override {
const MetaInfo& info = p_fmat->info();
const int ngroup = model.param.num_output_group;
size_t ncolumns = model.param.num_feature;
const unsigned row_chunk = ngroup * (ncolumns + 1) * (ncolumns + 1);
const unsigned mrow_chunk = (ncolumns + 1) * (ncolumns + 1);
const unsigned crow_chunk = ngroup * (ncolumns + 1);
// allocate space for (number of features^2) times the number of rows and tmp off/on contribs
std::vector<bst_float>& contribs = *out_contribs;
contribs.resize(info.num_row * ngroup * (ncolumns + 1) * (ncolumns + 1));
std::vector<bst_float> contribs_off(info.num_row * ngroup * (ncolumns + 1));
std::vector<bst_float> contribs_on(info.num_row * ngroup * (ncolumns + 1));
std::vector<bst_float> contribs_diag(info.num_row * ngroup * (ncolumns + 1));
// Compute the difference in effects when conditioning on each of the features on and off
// see: Axiomatic characterizations of probabilistic and
// cardinal-probabilistic interaction indices
PredictContribution(p_fmat, &contribs_diag, model, ntree_limit, approximate, 0, 0);
for (size_t i = 0; i < ncolumns + 1; ++i) {
PredictContribution(p_fmat, &contribs_off, model, ntree_limit, approximate, -1, i);
PredictContribution(p_fmat, &contribs_on, model, ntree_limit, approximate, 1, i);
for (size_t j = 0; j < info.num_row; ++j) {
for (int l = 0; l < ngroup; ++l) {
const unsigned o_offset = j * row_chunk + l * mrow_chunk + i * (ncolumns + 1);
const unsigned c_offset = j * crow_chunk + l * (ncolumns + 1);
contribs[o_offset + i] = 0;
for (size_t k = 0; k < ncolumns + 1; ++k) {
// fill in the diagonal with additive effects, and off-diagonal with the interactions
if (k == i) {
contribs[o_offset + i] += contribs_diag[c_offset + k];
} else {
contribs[o_offset + k] = (contribs_on[c_offset + k] - contribs_off[c_offset + k])/2.0;
contribs[o_offset + i] -= contribs[o_offset + k];
}
}
}
}
}
}
std::vector<RegTree::FVec> thread_temp;
};
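A hedged numpy restatement of how the loop above assembles one row's interaction matrix from the three contribution passes (`phi` is the unconditional SHAP vector, `phi_on[i]` / `phi_off[i]` the SHAP vectors with feature i conditioned on / off, all of length nfeats + 1 with the bias last; names are illustrative):

```python
import numpy as np

def assemble_interactions(phi, phi_on, phi_off):
    n = len(phi)
    out = np.zeros((n, n))
    for i in range(n):
        for k in range(n):
            if k != i:
                # off-diagonals get half of the interaction effect, since each
                # pair (i, k) appears in both row i and row k
                out[i, k] = (phi_on[i][k] - phi_off[i][k]) / 2.0
        # the diagonal holds the main effect: the SHAP value minus the row's
        # off-diagonal interactions, so each row sums back to phi[i]
        out[i, i] = phi[i] - out[i].sum()
    return out
```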


@@ -454,10 +454,22 @@ class GPUPredictor : public xgboost::Predictor {
void PredictContribution(DMatrix* p_fmat,
std::vector<bst_float>* out_contribs,
const gbm::GBTreeModel& model,
unsigned ntree_limit,
bool approximate,
int condition,
unsigned condition_feature) override {
cpu_predictor->PredictContribution(p_fmat, out_contribs, model,
ntree_limit, approximate, condition, condition_feature);
}
void PredictInteractionContributions(DMatrix* p_fmat,
std::vector<bst_float>* out_contribs,
const gbm::GBTreeModel& model,
unsigned ntree_limit,
bool approximate) override {
cpu_predictor->PredictInteractionContributions(p_fmat, out_contribs, model,
ntree_limit, approximate);
}
void Init(const std::vector<std::pair<std::string, std::string>>& cfg,


@@ -172,3 +172,15 @@ TEST(Objective, TweedieRegressionBasic) {
EXPECT_NEAR(preds[i], out_preds[i], 0.01f);
}
}
TEST(Objective, CoxRegressionGPair) {
xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("survival:cox");
std::vector<std::pair<std::string, std::string> > args;
obj->Configure(args);
CheckObjFunction(obj,
{ 0, 0.1f, 0.9f, 1, 0, 0.1f, 0.9f, 1},
{ 0, -2, -2, 2, 3, 5, -10, 100},
{ 1, 1, 1, 1, 1, 1, 1, 1},
{ 0, 0, 0, -0.799f, -0.788f, -0.590f, 0.910f, 1.006f},
{ 0, 0, 0, 0.160f, 0.186f, 0.348f, 0.610f, 0.639f});
}


@@ -2,7 +2,6 @@
import numpy as np
import xgboost as xgb
import unittest
import json
dpath = 'demo/data/'
@@ -143,35 +142,6 @@ class TestBasic(unittest.TestCase):
dm = xgb.DMatrix(dummy, feature_names=list('abcde'))
self.assertRaises(ValueError, bst.predict, dm)
def test_dump(self):
data = np.random.randn(100, 2)
target = np.array([0, 1] * 50)
@@ -268,41 +238,3 @@ class TestBasic(unittest.TestCase):
cv = xgb.cv(params, dm, num_boost_round=10, shuffle=False, nfold=10, as_pandas=False)
assert isinstance(cv, dict)
assert len(cv) == (4)

tests/python/test_shap.py (new file, 252 lines)

@@ -0,0 +1,252 @@
# -*- coding: utf-8 -*-
import numpy as np
import xgboost as xgb
import unittest
import itertools
import re
import scipy
import scipy.special
dpath = 'demo/data/'
rng = np.random.RandomState(1994)
class TestSHAP(unittest.TestCase):
def test_feature_importances(self):
data = np.random.randn(100, 5)
target = np.array([0, 1] * 50)
features = ['Feature1', 'Feature2', 'Feature3', 'Feature4', 'Feature5']
dm = xgb.DMatrix(data, label=target,
feature_names=features)
params = {'objective': 'multi:softprob',
'eval_metric': 'mlogloss',
'eta': 0.3,
'num_class': 3}
bst = xgb.train(params, dm, num_boost_round=10)
# number of feature importances should == number of features
scores1 = bst.get_score()
scores2 = bst.get_score(importance_type='weight')
scores3 = bst.get_score(importance_type='cover')
scores4 = bst.get_score(importance_type='gain')
assert len(scores1) == len(features)
assert len(scores2) == len(features)
assert len(scores3) == len(features)
assert len(scores4) == len(features)
# check backwards compatibility of get_fscore
fscores = bst.get_fscore()
assert scores1 == fscores
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
def fn(max_depth, num_rounds):
# train
params = {'max_depth': max_depth, 'eta': 1, 'silent': 1}
bst = xgb.train(params, dtrain, num_boost_round=num_rounds)
# predict
preds = bst.predict(dtest)
contribs = bst.predict(dtest, pred_contribs=True)
# result should be (number of features + BIAS) * number of rows
assert contribs.shape == (dtest.num_row(), dtest.num_col() + 1)
# sum of contributions should be same as predictions
np.testing.assert_array_almost_equal(np.sum(contribs, axis=1), preds)
# for max_depth, num_rounds in itertools.product(range(0, 3), range(1, 5)):
# yield fn, max_depth, num_rounds
# check that we get the right SHAP values for a basic AND example
# (https://arxiv.org/abs/1706.06060)
X = np.zeros((4, 2))
X[0, :] = 1
X[1, 0] = 1
X[2, 1] = 1
y = np.zeros(4)
y[0] = 1
param = {"max_depth": 2, "base_score": 0.0, "eta": 1.0, "lambda": 0}
bst = xgb.train(param, xgb.DMatrix(X, label=y), 1)
out = bst.predict(xgb.DMatrix(X[0:1, :]), pred_contribs=True)
assert out[0, 0] == 0.375
assert out[0, 1] == 0.375
assert out[0, 2] == 0.25
def parse_model(model):
trees = []
r_exp = r"([0-9]+):\[f([0-9]+)<([0-9\.e-]+)\] yes=([0-9]+),no=([0-9]+).*cover=([0-9e\.]+)"
r_exp_leaf = r"([0-9]+):leaf=([0-9\.e-]+),cover=([0-9e\.]+)"
for tree in model.get_dump(with_stats=True):
lines = list(tree.splitlines())
trees.append([None for i in range(len(lines))])
for line in lines:
match = re.search(r_exp, line)
if match is not None:
ind = int(match.group(1))
while ind >= len(trees[-1]):
trees[-1].append(None)
trees[-1][ind] = {
"yes_ind": int(match.group(4)),
"no_ind": int(match.group(5)),
"value": None,
"threshold": float(match.group(3)),
"feature_index": int(match.group(2)),
"cover": float(match.group(6))
}
else:
match = re.search(r_exp_leaf, line)
ind = int(match.group(1))
while ind >= len(trees[-1]):
trees[-1].append(None)
trees[-1][ind] = {
"value": float(match.group(2)),
"cover": float(match.group(3))
}
return trees
def exp_value_rec(tree, z, x, i=0):
if tree[i]["value"] is not None:
return tree[i]["value"]
else:
ind = tree[i]["feature_index"]
if z[ind] == 1:
if x[ind] < tree[i]["threshold"]:
return exp_value_rec(tree, z, x, tree[i]["yes_ind"])
else:
return exp_value_rec(tree, z, x, tree[i]["no_ind"])
else:
r_yes = tree[tree[i]["yes_ind"]]["cover"] / tree[i]["cover"]
out = exp_value_rec(tree, z, x, tree[i]["yes_ind"])
val = out * r_yes
r_no = tree[tree[i]["no_ind"]]["cover"] / tree[i]["cover"]
out = exp_value_rec(tree, z, x, tree[i]["no_ind"])
val += out * r_no
return val
def exp_value(trees, z, x):
return np.sum([exp_value_rec(tree, z, x) for tree in trees])
def all_subsets(ss):
return itertools.chain(*map(lambda x: itertools.combinations(ss, x), range(0, len(ss) + 1)))
def shap_value(trees, x, i, cond=None, cond_value=None):
M = len(x)
z = np.zeros(M)
other_inds = list(set(range(M)) - set([i]))
if cond is not None:
other_inds = list(set(other_inds) - set([cond]))
z[cond] = cond_value
M -= 1
total = 0.0
for subset in all_subsets(other_inds):
if len(subset) > 0:
z[list(subset)] = 1
v1 = exp_value(trees, z, x)
z[i] = 1
v2 = exp_value(trees, z, x)
total += (v2 - v1) / (scipy.special.binom(M - 1, len(subset)) * M)
z[i] = 0
z[list(subset)] = 0
return total
def shap_values(trees, x):
vals = [shap_value(trees, x, i) for i in range(len(x))]
vals.append(exp_value(trees, np.zeros(len(x)), x))
return np.array(vals)
def interaction_values(trees, x):
M = len(x)
out = np.zeros((M + 1, M + 1))
for i in range(len(x)):
for j in range(len(x)):
if i != j:
out[i, j] = interaction_value(trees, x, i, j) / 2
svals = shap_values(trees, x)
main_effects = svals - out.sum(1)
out[np.diag_indices_from(out)] = main_effects
return out
def interaction_value(trees, x, i, j):
M = len(x)
z = np.zeros(M)
other_inds = list(set(range(M)) - set([i, j]))
total = 0.0
for subset in all_subsets(other_inds):
if len(subset) > 0:
z[list(subset)] = 1
v00 = exp_value(trees, z, x)
z[i] = 1
v10 = exp_value(trees, z, x)
z[j] = 1
v11 = exp_value(trees, z, x)
z[i] = 0
v01 = exp_value(trees, z, x)
z[j] = 0
total += (v11 - v01 - v10 + v00) / (scipy.special.binom(M - 2, len(subset)) * (M - 1))
z[list(subset)] = 0
return total
# test a simple and function
M = 2
N = 4
X = np.zeros((N, M))
X[0, :] = 1
X[1, 0] = 1
X[2, 1] = 1
y = np.zeros(N)
y[0] = 1
param = {"max_depth": 2, "base_score": 0.0, "eta": 1.0, "lambda": 0}
bst = xgb.train(param, xgb.DMatrix(X, label=y), 1)
brute_force = shap_values(parse_model(bst), X[0, :])
fast_method = bst.predict(xgb.DMatrix(X[0:1, :]), pred_contribs=True)
assert np.linalg.norm(brute_force - fast_method[0, :]) < 1e-4
brute_force = interaction_values(parse_model(bst), X[0, :])
fast_method = bst.predict(xgb.DMatrix(X[0:1, :]), pred_interactions=True)
assert np.linalg.norm(brute_force - fast_method[0, :, :]) < 1e-4
# test a random function
np.random.seed(0)
M = 2
N = 4
X = np.random.randn(N, M)
y = np.random.randn(N)
param = {"max_depth": 2, "base_score": 0.0, "eta": 1.0, "lambda": 0}
bst = xgb.train(param, xgb.DMatrix(X, label=y), 1)
brute_force = shap_values(parse_model(bst), X[0, :])
fast_method = bst.predict(xgb.DMatrix(X[0:1, :]), pred_contribs=True)
assert np.linalg.norm(brute_force - fast_method[0, :]) < 1e-4
brute_force = interaction_values(parse_model(bst), X[0, :])
fast_method = bst.predict(xgb.DMatrix(X[0:1, :]), pred_interactions=True)
assert np.linalg.norm(brute_force - fast_method[0, :, :]) < 1e-4
# test another larger more complex random function
np.random.seed(0)
M = 5
N = 100
X = np.random.randn(N, M)
y = np.random.randn(N)
base_score = 1.0
param = {"max_depth": 5, "base_score": base_score, "eta": 0.1, "gamma": 2.0}
bst = xgb.train(param, xgb.DMatrix(X, label=y), 10)
brute_force = shap_values(parse_model(bst), X[0, :])
brute_force[-1] += base_score
fast_method = bst.predict(xgb.DMatrix(X[0:1, :]), pred_contribs=True)
assert np.linalg.norm(brute_force - fast_method[0, :]) < 1e-4
brute_force = interaction_values(parse_model(bst), X[0, :])
brute_force[-1, -1] += base_score
fast_method = bst.predict(xgb.DMatrix(X[0:1, :]), pred_interactions=True)
assert np.linalg.norm(brute_force - fast_method[0, :, :]) < 1e-4


@@ -103,7 +103,7 @@ if [ ${TASK} == "cmake_test" ]; then
# Build/test without AVX
mkdir build && cd build
cmake .. -DGOOGLE_TEST=ON
make
cd ..
./testxgboost