From 6bd1869026eb66a542ccd8573f1e4b098a8cf7c6 Mon Sep 17 00:00:00 2001
From: Maurus Cuelenaere
Date: Sun, 14 May 2017 07:58:10 +0200
Subject: [PATCH] Add prediction of feature contributions (#2003)

* Add prediction of feature contributions

This implements the idea described at
http://blog.datadive.net/interpreting-random-forests/ which tries to give
insight into how a prediction is composed of its feature contributions and
a bias.

* Support multi-class models

* Calculate learning_rate per-tree instead of using the one from the first tree

* Do not rely on node.base_weight * learning_rate having the same value as the
  node mean value (aka leaf value, if it were a leaf); instead calculate them
  (lazily) on-the-fly

* Add simple test for contributions feature

* Check against param.num_nodes instead of checking for non-zero length

* Loop over all roots instead of only the first
---
 include/xgboost/c_api.h        |  1 +
 include/xgboost/gbm.h          | 13 ++++++
 include/xgboost/learner.h      |  4 +-
 include/xgboost/tree_model.h   | 78 ++++++++++++++++++++++++++++++++++
 python-package/xgboost/core.py | 11 ++++-
 src/c_api/c_api.cc             |  3 +-
 src/gbm/gblinear.cc            |  5 +++
 src/gbm/gbtree.cc              | 64 ++++++++++++++++++++++++++++
 src/learner.cc                 |  7 ++-
 tests/python/test_basic.py     | 24 +++++++++++
 10 files changed, 205 insertions(+), 5 deletions(-)

diff --git a/include/xgboost/c_api.h b/include/xgboost/c_api.h
index 08d9ff312..f4792ad5a 100644
--- a/include/xgboost/c_api.h
+++ b/include/xgboost/c_api.h
@@ -382,6 +382,7 @@ XGB_DLL int XGBoosterEvalOneIter(BoosterHandle handle,
  *    0:normal prediction
  *    1:output margin instead of transformed value
  *    2:output leaf index of trees instead of leaf value, note leaf index is unique per tree
+ *    4:output feature contributions of all trees instead of predictions
  * \param ntree_limit limit number of trees used for prediction, this is only valid for boosted trees
  *    when the parameter is set to 0, we will use all the trees
  * \param out_len used to store length of returning result
diff --git a/include/xgboost/gbm.h b/include/xgboost/gbm.h
index 10ede8500..aadd074f8 100644
--- a/include/xgboost/gbm.h
+++ b/include/xgboost/gbm.h
@@ -107,6 +107,19 @@ class GradientBooster {
   virtual void PredictLeaf(DMatrix* dmat,
                            std::vector<bst_float>* out_preds,
                            unsigned ntree_limit = 0) = 0;
+
+  /*!
+   * \brief predict the feature contributions of each tree, the output will be an nsample * (nfeats + 1) vector
+   *        this is only valid in gbtree predictor
+   * \param dmat feature matrix
+   * \param out_contribs output vector to hold the contributions
+   * \param ntree_limit limit the number of trees used in prediction, when it equals 0, this means
+   *    we do not limit number of trees
+   */
+  virtual void PredictContribution(DMatrix* dmat,
+                                   std::vector<bst_float>* out_contribs,
+                                   unsigned ntree_limit = 0) = 0;
+
   /*!
   * \brief dump the model in the requested format
   * \param fmap feature map that may help give interpretations of feature
diff --git a/include/xgboost/learner.h b/include/xgboost/learner.h
index 86daa113b..4733f5522 100644
--- a/include/xgboost/learner.h
+++ b/include/xgboost/learner.h
@@ -103,12 +103,14 @@ class Learner : public rabit::Serializable {
    * \param ntree_limit limit number of trees used for boosted tree
    *   predictor, when it equals 0, this means we are using all the trees
    * \param pred_leaf whether to only predict the leaf index of each tree in a boosted tree predictor
+   * \param pred_contribs whether to only predict the feature contributions of all trees
    */
   virtual void Predict(DMatrix* data,
                        bool output_margin,
                        std::vector<bst_float> *out_preds,
                        unsigned ntree_limit = 0,
-                       bool pred_leaf = false) const = 0;
+                       bool pred_leaf = false,
+                       bool pred_contribs = false) const = 0;
   /*!
    * \brief Set additional attribute to the Booster.
    *  The property will be saved along the booster.
diff --git a/include/xgboost/tree_model.h b/include/xgboost/tree_model.h
index 59d84b227..e9319c4ff 100644
--- a/include/xgboost/tree_model.h
+++ b/include/xgboost/tree_model.h
@@ -434,6 +434,11 @@ class RegTree: public TreeModel<bst_float, RTreeNodeStat> {
     * \param inst The sparse instance to drop.
     */
    inline void Drop(const RowBatch::Inst& inst);
+   /*!
+    * \brief returns the size of the feature vector
+    * \return the size of the feature vector
+    */
+   inline size_t size() const;
    /*!
     * \brief get ith value
     * \param i feature index.
@@ -472,6 +477,14 @@ class RegTree: public TreeModel<bst_float, RTreeNodeStat> {
    * \return the leaf index of the given feature
    */
  inline bst_float Predict(const FVec& feat, unsigned root_id = 0) const;
+  /*!
+   * \brief calculate the feature contributions for the given root
+   * \param feat dense feature vector, if the feature is missing the field is set to NaN
+   * \param root_id starting root index of the instance
+   * \param out_contribs output vector to hold the contributions
+   */
+  inline void CalculateContributions(const RegTree::FVec& feat, unsigned root_id,
+                                     bst_float *out_contribs) const;
   /*!
    * \brief get next position of the tree given current pid
    * \param pid Current node id.
@@ -489,6 +502,15 @@ class RegTree: public TreeModel<bst_float, RTreeNodeStat> {
   std::string DumpModel(const FeatureMap& fmap,
                         bool with_stats,
                         std::string format) const;
+  /*!
+   * \brief calculate the mean value for each node, required for feature contributions
+   */
+  inline void FillNodeMeanValues();
+
+ private:
+  inline bst_float FillNodeMeanValue(int nid);
+
+  std::vector<bst_float> node_mean_values;
 };
 
 // implementations of inline functions
@@ -513,6 +535,10 @@ inline void RegTree::FVec::Drop(const RowBatch::Inst& inst) {
   }
 }
 
+inline size_t RegTree::FVec::size() const {
+  return data.size();
+}
+
 inline bst_float RegTree::FVec::fvalue(size_t i) const {
   return data[i].fvalue;
 }
@@ -535,6 +561,58 @@ inline bst_float RegTree::Predict(const RegTree::FVec& feat, unsigned root_id) c
   return (*this)[pid].leaf_value();
 }
 
+inline void RegTree::FillNodeMeanValues() {
+  size_t num_nodes = this->param.num_nodes;
+  if (this->node_mean_values.size() == num_nodes) {
+    return;
+  }
+  this->node_mean_values.resize(num_nodes);
+  for (int root_id = 0; root_id < param.num_roots; ++root_id) {
+    this->FillNodeMeanValue(root_id);
+  }
+}
+
+inline bst_float RegTree::FillNodeMeanValue(int nid) {
+  bst_float result;
+  auto& node = (*this)[nid];
+  if (node.is_leaf()) {
+    result = node.leaf_value();
+  } else {
+    result = this->FillNodeMeanValue(node.cleft()) * this->stat(node.cleft()).sum_hess;
+    result += this->FillNodeMeanValue(node.cright()) * this->stat(node.cright()).sum_hess;
+    result /= this->stat(nid).sum_hess;
+  }
+  this->node_mean_values[nid] = result;
+  return result;
+}
+
+inline void RegTree::CalculateContributions(const RegTree::FVec& feat, unsigned root_id,
+                                            bst_float *out_contribs) const {
+  CHECK_GT(this->node_mean_values.size(), 0);
+  // this follows the idea of http://blog.datadive.net/interpreting-random-forests/
+  bst_float node_value;
+  unsigned split_index;
+  int pid = static_cast<int>(root_id);
+  // update bias value
+  node_value = this->node_mean_values[pid];
+  out_contribs[feat.size()] += node_value;
+  if ((*this)[pid].is_leaf()) {
+    // nothing to do anymore
+    return;
+  }
+  while (!(*this)[pid].is_leaf()) {
+    split_index = (*this)[pid].split_index();
+    pid = this->GetNext(pid, feat.fvalue(split_index), feat.is_missing(split_index));
+    bst_float new_value = this->node_mean_values[pid];
+    // update feature weight
+    out_contribs[split_index] += new_value - node_value;
+    node_value = new_value;
+  }
+  bst_float leaf_value = (*this)[pid].leaf_value();
+  // update leaf feature weight
+  out_contribs[split_index] += leaf_value - node_value;
+}
+
 /*! \brief get next position of the tree given current pid */
 inline int RegTree::GetNext(int pid, bst_float fvalue, bool is_unknown) const {
   bst_float split_value = (*this)[pid].split_cond();
diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py
index 0d46893b6..2d1c5e7c7 100644
--- a/python-package/xgboost/core.py
+++ b/python-package/xgboost/core.py
@@ -911,7 +911,8 @@ class Booster(object):
         self._validate_features(data)
         return self.eval_set([(data, name)], iteration)
 
-    def predict(self, data, output_margin=False, ntree_limit=0, pred_leaf=False):
+    def predict(self, data, output_margin=False, ntree_limit=0, pred_leaf=False,
+                pred_contribs=False):
         """
         Predict with data.
 
@@ -937,6 +938,12 @@ class Booster(object):
             Note that the leaf index of a tree is unique per tree, so you may find leaf 1
             in both tree 1 and tree 0.
 
+        pred_contribs : bool
+            When this option is on, the output will be a matrix of (nsample, nfeats+1)
+            with each record indicating the feature contributions of all trees. The sum of
+            all feature contributions is equal to the prediction. Note that the bias is added
+            as the final column, on top of the regular features.
+
         Returns
         -------
         prediction : numpy array
@@ -946,6 +953,8 @@
             option_mask |= 0x01
         if pred_leaf:
             option_mask |= 0x02
+        if pred_contribs:
+            option_mask |= 0x04
 
         self._validate_features(data)
 
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index 5f9c3b07e..5a6ba6e5c 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -622,7 +622,8 @@ XGB_DLL int XGBoosterPredict(BoosterHandle handle,
       static_cast<std::shared_ptr<DMatrix>*>(dmat)->get(),
       (option_mask & 1) != 0,
       &preds, ntree_limit,
-      (option_mask & 2) != 0);
+      (option_mask & 2) != 0,
+      (option_mask & 4) != 0);
   *out_result = dmlc::BeginPtr(preds);
   *len = static_cast<bst_ulong>(preds.size());
   API_END();
diff --git a/src/gbm/gblinear.cc b/src/gbm/gblinear.cc
index cb2252256..168fe4b57 100644
--- a/src/gbm/gblinear.cc
+++ b/src/gbm/gblinear.cc
@@ -223,6 +223,11 @@ class GBLinear : public GradientBooster {
                    unsigned ntree_limit) override {
     LOG(FATAL) << "gblinear does not support predict leaf index";
   }
+  void PredictContribution(DMatrix* p_fmat,
+                           std::vector<bst_float>* out_contribs,
+                           unsigned ntree_limit) override {
+    LOG(FATAL) << "gblinear does not support predict contributions";
+  }
 
   std::vector<std::string> DumpModel(const FeatureMap& fmap,
                                      bool with_stats,
diff --git a/src/gbm/gbtree.cc b/src/gbm/gbtree.cc
index b54a304a7..1f79a58c4 100644
--- a/src/gbm/gbtree.cc
+++ b/src/gbm/gbtree.cc
@@ -322,6 +322,14 @@ class GBTree : public GradientBooster {
     this->PredPath(p_fmat, out_preds, ntree_limit);
   }
 
+  void PredictContribution(DMatrix* p_fmat,
+                           std::vector<bst_float>* out_contribs,
+                           unsigned ntree_limit) override {
+    const int nthread = omp_get_max_threads();
+    InitThreadTemp(nthread);
+    this->PredContrib(p_fmat, out_contribs, ntree_limit);
+  }
+
   std::vector<std::string> DumpModel(const FeatureMap& fmap,
                                      bool with_stats,
                                      std::string format) const override {
@@ -553,6 +561,62 @@ class GBTree : public GradientBooster {
       }
     }
   }
+  // predict contributions
+  inline void PredContrib(DMatrix *p_fmat,
+                          std::vector<bst_float> *out_contribs,
+                          unsigned ntree_limit) {
+    const MetaInfo& info = p_fmat->info();
+    // number of valid trees
+    ntree_limit *= mparam.num_output_group;
+    if (ntree_limit == 0 || ntree_limit > trees.size()) {
+      ntree_limit = static_cast<unsigned>(trees.size());
+    }
+    size_t ncolumns = mparam.num_feature + 1;
+    // allocate space for (number of features + bias) times the number of rows
+    std::vector<bst_float>& contribs = *out_contribs;
+    contribs.resize(info.num_row * ncolumns * mparam.num_output_group);
+    // make sure contributions is zeroed, we could be reusing a previously allocated one
+    std::fill(contribs.begin(), contribs.end(), 0);
+    // initialize tree node mean values
+    #pragma omp parallel for schedule(static)
+    for (bst_omp_uint i = 0; i < ntree_limit; ++i) {
+      trees[i]->FillNodeMeanValues();
+    }
+    // start collecting the contributions
+    dmlc::DataIter<RowBatch>* iter = p_fmat->RowIterator();
+    const std::vector<bst_float>& base_margin = p_fmat->info().base_margin;
+    iter->BeforeFirst();
+    while (iter->Next()) {
+      const RowBatch& batch = iter->Value();
+      // parallel over local batch
+      const bst_omp_uint nsize = static_cast<bst_omp_uint>(batch.size);
+      #pragma omp parallel for schedule(static)
+      for (bst_omp_uint i = 0; i < nsize; ++i) {
+        size_t row_idx = static_cast<size_t>(batch.base_rowid + i);
+        unsigned root_id = info.GetRoot(row_idx);
+        RegTree::FVec &feats = thread_temp[omp_get_thread_num()];
+        // loop over all classes
+        for (int gid = 0; gid < mparam.num_output_group; ++gid) {
+          bst_float *p_contribs = &contribs[(row_idx * mparam.num_output_group + gid) * ncolumns];
+          feats.Fill(batch[i]);
+          // calculate contributions
+          for (unsigned j = 0; j < ntree_limit; ++j) {
+            if (tree_info[j] != gid) {
+              continue;
+            }
+            trees[j]->CalculateContributions(feats, root_id, p_contribs);
+          }
+          feats.Drop(batch[i]);
+          // add base margin to BIAS feature
+          if (base_margin.size() != 0) {
+            p_contribs[ncolumns - 1] += base_margin[row_idx * mparam.num_output_group + gid];
+          } else {
+            p_contribs[ncolumns - 1] += base_margin_;
+          }
+        }
+      }
+    }
+  }
   // init thread buffers
   inline void InitThreadTemp(int nthread) {
     int prev_thread_temp_size = thread_temp.size();
diff --git a/src/learner.cc b/src/learner.cc
index 8c92556fc..2622ff4fb 100644
--- a/src/learner.cc
+++ b/src/learner.cc
@@ -400,8 +400,11 @@ class LearnerImpl : public Learner {
                bool output_margin,
                std::vector<bst_float> *out_preds,
                unsigned ntree_limit,
-               bool pred_leaf) const override {
-    if (pred_leaf) {
+               bool pred_leaf,
+               bool pred_contribs) const override {
+    if (pred_contribs) {
+      gbm_->PredictContribution(data, out_preds, ntree_limit);
+    } else if (pred_leaf) {
       gbm_->PredictLeaf(data, out_preds, ntree_limit);
     } else {
       this->PredictRaw(data, out_preds, ntree_limit);
diff --git a/tests/python/test_basic.py b/tests/python/test_basic.py
index c56dfde3f..ffbeb518c 100644
--- a/tests/python/test_basic.py
+++ b/tests/python/test_basic.py
@@ -2,6 +2,7 @@
 import numpy as np
 import xgboost as xgb
 import unittest
+import itertools
 import json
 
 dpath = 'demo/data/'
@@ -250,3 +251,26 @@ class TestBasic(unittest.TestCase):
         cv = xgb.cv(params, dm, num_boost_round=10, shuffle=False, nfold=10, as_pandas=False)
         assert isinstance(cv, dict)
         assert len(cv) == (4)
+
+
+def test_contributions():
+    dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
+    dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
+
+    def test_fn(max_depth, num_rounds):
+        # train
+        params = {'max_depth': max_depth, 'eta': 1, 'silent': 1}
+        bst = xgb.train(params, dtrain, num_boost_round=num_rounds)
+
+        # predict
+        preds = bst.predict(dtest)
+        contribs = bst.predict(dtest, pred_contribs=True)
+
+        # result should be (number of features + BIAS) * number of rows
+        assert contribs.shape == (dtest.num_row(), dtest.num_col() + 1)
+
+        # sum of contributions should be same as predictions
+        np.testing.assert_array_almost_equal(np.sum(contribs, axis=1), preds)
+
+    for max_depth, num_rounds in itertools.product(range(0, 3), range(1, 5)):
+        yield test_fn, max_depth, num_rounds
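
A minimal usage sketch, mirroring the new test in tests/python/test_basic.py: it assumes
the agaricus demo data paths from the repository and the default (untransformed) objective
used by that test, so the row-wise sum of the contribution columns, bias column included,
reproduces the prediction returned by predict().

    import numpy as np
    import xgboost as xgb

    dtrain = xgb.DMatrix('demo/data/agaricus.txt.train')
    dtest = xgb.DMatrix('demo/data/agaricus.txt.test')
    bst = xgb.train({'max_depth': 2, 'eta': 1, 'silent': 1}, dtrain, num_boost_round=2)

    # one row per sample, one column per feature plus a trailing bias column
    contribs = bst.predict(dtest, pred_contribs=True)
    assert contribs.shape == (dtest.num_row(), dtest.num_col() + 1)

    # the contribution columns (features + bias) sum to the prediction
    preds = bst.predict(dtest)
    np.testing.assert_array_almost_equal(contribs.sum(axis=1), preds)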
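
For reference, a small Python sketch of the per-tree decomposition that
RegTree::FillNodeMeanValue and RegTree::CalculateContributions implement above. The
dict-of-dicts tree representation and its keys ('is_leaf', 'split_index', 'split_cond',
'left', 'right', 'sum_hess', 'leaf_value') are hypothetical, and missing-value routing
and multiple roots are ignored; only the idea is kept: a node's hessian-weighted mean
value is the expected prediction below it, the root's mean becomes the bias, and each
split on the decision path credits the change in mean value to its split feature.

    import numpy as np

    def node_mean(tree, nid):
        # hessian-weighted mean value of a node (leaf value for leaves)
        node = tree[nid]
        if node['is_leaf']:
            return node['leaf_value']
        left, right = node['left'], node['right']
        return (node_mean(tree, left) * tree[left]['sum_hess'] +
                node_mean(tree, right) * tree[right]['sum_hess']) / node['sum_hess']

    def contributions(tree, x, n_features):
        # decompose one tree's prediction for x into per-feature contributions + bias
        out = np.zeros(n_features + 1)
        nid = 0                               # single root assumed
        value = node_mean(tree, nid)
        out[-1] += value                      # bias: mean value of the root
        while not tree[nid]['is_leaf']:
            feat = tree[nid]['split_index']
            nid = tree[nid]['left'] if x[feat] < tree[nid]['split_cond'] else tree[nid]['right']
            new_value = node_mean(tree, nid)  # equals the leaf value when nid is a leaf
            out[feat] += new_value - value    # credit the change to the split feature
            value = new_value
        return out                            # out.sum() equals the leaf value reached by x

    # toy stump: x[0] < 0.5 -> -1.0 (hess 3), else +2.0 (hess 1)
    toy = {
        0: {'is_leaf': False, 'split_index': 0, 'split_cond': 0.5,
            'left': 1, 'right': 2, 'sum_hess': 4.0},
        1: {'is_leaf': True, 'leaf_value': -1.0, 'sum_hess': 3.0},
        2: {'is_leaf': True, 'leaf_value': 2.0, 'sum_hess': 1.0},
    }
    print(contributions(toy, x=np.array([0.9]), n_features=1))
    # bias = (-1*3 + 2*1)/4 = -0.25; feature 0 gets 2.0 - (-0.25) = 2.25; sum = 2.0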