Add prediction of feature contributions (#2003)
* Add prediction of feature contributions. This implements the idea described at http://blog.datadive.net/interpreting-random-forests/, which tries to give insight into how a prediction is composed of its feature contributions and a bias.
* Support multi-class models.
* Calculate learning_rate per tree instead of using the one from the first tree.
* Do not rely on node.base_weight * learning_rate having the same value as the node mean value (aka leaf value, if it were a leaf); instead calculate them (lazily) on the fly.
* Add a simple test for the contributions feature.
* Check against param.num_nodes instead of checking for non-zero length.
* Loop over all roots instead of only the first.
Committed by Vadim Khotilovich, parent e62be19c70, commit 6bd1869026
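
The decomposition follows the approach in the linked post: an example is walked down each tree, and at every split the change in the node's mean response between the current node and the chosen child is credited to the feature that was split on, while the mean value at the root becomes the bias. In this commit the per-node means are filled in lazily by RegTree::FillNodeMeanValues and the walk itself lives in RegTree::CalculateContributions, neither of which appears in the excerpt below, so the following is only an illustrative sketch on a toy array-based tree; ToyNode and ToyContributions are hypothetical names, not part of XGBoost.

    #include <vector>

    // Toy node layout for illustration only; this is not XGBoost's RegTree.
    struct ToyNode {
      int split_feature;   // -1 marks a leaf
      float split_value;
      int left, right;     // child indices
      float mean_value;    // mean prediction over training rows reaching this node
    };

    // Sketch of the decomposition: credit (child mean - node mean) to the split
    // feature at every step of the decision path; the root mean is the bias.
    // contribs must hold one slot per feature and start zeroed.
    float ToyContributions(const std::vector<ToyNode>& tree,
                           const std::vector<float>& x,
                           std::vector<float>* contribs) {
      int nid = 0;                              // start at the root
      const float bias = tree[0].mean_value;
      while (tree[nid].split_feature >= 0) {
        const ToyNode& n = tree[nid];
        const int next = (x[n.split_feature] < n.split_value) ? n.left : n.right;
        (*contribs)[n.split_feature] += tree[next].mean_value - n.mean_value;
        nid = next;
      }
      // The telescoping sum gives: bias + sum(contribs) == leaf mean == tree output.
      return bias;
    }

Summed over the trees of one output group, and with the base margin added to the bias slot (see PredContrib below), this recovers the model's raw margin prediction.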
@@ -622,7 +622,8 @@ XGB_DLL int XGBoosterPredict(BoosterHandle handle,
       static_cast<std::shared_ptr<DMatrix>*>(dmat)->get(),
       (option_mask & 1) != 0,
       &preds, ntree_limit,
-      (option_mask & 2) != 0);
+      (option_mask & 2) != 0,
+      (option_mask & 4) != 0);
   *out_result = dmlc::BeginPtr(preds);
   *len = static_cast<xgboost::bst_ulong>(preds.size());
   API_END();
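
On the C API side the new behaviour is selected through a third option_mask bit: 1 still requests the margin, 2 the leaf indices, and the new 4 the per-feature contributions. Below is a minimal, hedged caller sketch; it assumes the XGBoosterPredict signature from xgboost/c_api.h of this era, booster and DMatrix handles created elsewhere, and a single output group. The helper name print_bias_contributions is purely illustrative.

    #include <xgboost/c_api.h>
    #include <cstdio>

    // Illustrative only: request contribution predictions via option_mask = 4
    // and print the bias column (the last of num_feature + 1 columns per row).
    void print_bias_contributions(BoosterHandle booster, DMatrixHandle dmat,
                                  bst_ulong num_feature) {
      bst_ulong out_len = 0;
      const float* out_result = nullptr;
      // option_mask bits: 1 = output margin, 2 = leaf indices, 4 = contributions
      XGBoosterPredict(booster, dmat, /*option_mask=*/4, /*ntree_limit=*/0,
                       &out_len, &out_result);
      const bst_ulong ncolumns = num_feature + 1;  // features plus bias
      for (bst_ulong row = 0; row < out_len / ncolumns; ++row) {
        std::printf("row %llu bias: %f\n",
                    static_cast<unsigned long long>(row),
                    out_result[row * ncolumns + ncolumns - 1]);
      }
    }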
@@ -223,6 +223,11 @@ class GBLinear : public GradientBooster {
                    unsigned ntree_limit) override {
     LOG(FATAL) << "gblinear does not support predict leaf index";
   }
+  void PredictContribution(DMatrix* p_fmat,
+                           std::vector<bst_float>* out_contribs,
+                           unsigned ntree_limit) override {
+    LOG(FATAL) << "gblinear does not support predict contributions";
+  }
 
   std::vector<std::string> DumpModel(const FeatureMap& fmap,
                                      bool with_stats,
@@ -322,6 +322,14 @@ class GBTree : public GradientBooster {
     this->PredPath(p_fmat, out_preds, ntree_limit);
   }
 
+  void PredictContribution(DMatrix* p_fmat,
+                           std::vector<bst_float>* out_contribs,
+                           unsigned ntree_limit) override {
+    const int nthread = omp_get_max_threads();
+    InitThreadTemp(nthread);
+    this->PredContrib(p_fmat, out_contribs, ntree_limit);
+  }
+
   std::vector<std::string> DumpModel(const FeatureMap& fmap,
                                      bool with_stats,
                                      std::string format) const override {
@@ -553,6 +561,62 @@ class GBTree : public GradientBooster {
       }
     }
   }
+  // predict contributions
+  inline void PredContrib(DMatrix *p_fmat,
+                          std::vector<bst_float> *out_contribs,
+                          unsigned ntree_limit) {
+    const MetaInfo& info = p_fmat->info();
+    // number of valid trees
+    ntree_limit *= mparam.num_output_group;
+    if (ntree_limit == 0 || ntree_limit > trees.size()) {
+      ntree_limit = static_cast<unsigned>(trees.size());
+    }
+    size_t ncolumns = mparam.num_feature + 1;
+    // allocate space for (number of features + bias) times the number of rows
+    std::vector<bst_float>& contribs = *out_contribs;
+    contribs.resize(info.num_row * ncolumns * mparam.num_output_group);
+    // make sure contributions is zeroed, we could be reusing a previously allocated one
+    std::fill(contribs.begin(), contribs.end(), 0);
+    // initialize tree node mean values
+    #pragma omp parallel for schedule(static)
+    for (bst_omp_uint i = 0; i < ntree_limit; ++i) {
+      trees[i]->FillNodeMeanValues();
+    }
+    // start collecting the contributions
+    dmlc::DataIter<RowBatch>* iter = p_fmat->RowIterator();
+    const std::vector<bst_float>& base_margin = p_fmat->info().base_margin;
+    iter->BeforeFirst();
+    while (iter->Next()) {
+      const RowBatch& batch = iter->Value();
+      // parallel over local batch
+      const bst_omp_uint nsize = static_cast<bst_omp_uint>(batch.size);
+      #pragma omp parallel for schedule(static)
+      for (bst_omp_uint i = 0; i < nsize; ++i) {
+        size_t row_idx = static_cast<size_t>(batch.base_rowid + i);
+        unsigned root_id = info.GetRoot(row_idx);
+        RegTree::FVec &feats = thread_temp[omp_get_thread_num()];
+        // loop over all classes
+        for (int gid = 0; gid < mparam.num_output_group; ++gid) {
+          bst_float *p_contribs =
+              &contribs[(row_idx * mparam.num_output_group + gid) * ncolumns];
+          feats.Fill(batch[i]);
+          // calculate contributions
+          for (unsigned j = 0; j < ntree_limit; ++j) {
+            if (tree_info[j] != gid) {
+              continue;
+            }
+            trees[j]->CalculateContributions(feats, root_id, p_contribs);
+          }
+          feats.Drop(batch[i]);
+          // add base margin to BIAS feature
+          if (base_margin.size() != 0) {
+            p_contribs[ncolumns - 1] += base_margin[row_idx * mparam.num_output_group + gid];
+          } else {
+            p_contribs[ncolumns - 1] += base_margin_;
+          }
+        }
+      }
+    }
+  }
   // init thread buffers
   inline void InitThreadTemp(int nthread) {
     int prev_thread_temp_size = thread_temp.size();
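
PredContrib lays the result out as one flat buffer of num_row * num_output_group * (num_feature + 1) floats: for each row and each class there are num_feature contribution slots followed by one bias slot, into which the base margin is also folded. A small hedged indexing helper makes the layout explicit; contrib_at is a hypothetical name, not an XGBoost function.

    #include <cstddef>
    #include <vector>

    // Hypothetical helper, not part of XGBoost: read one entry from the flat
    // buffer filled by PredContrib above. Passing feature == num_feature
    // selects the bias column.
    inline float contrib_at(const std::vector<float>& contribs,
                            std::size_t row, int group, std::size_t feature,
                            int num_group, std::size_t num_feature) {
      const std::size_t ncolumns = num_feature + 1;  // features plus bias
      return contribs[(row * num_group + group) * ncolumns + feature];
    }

Summing the num_feature + 1 entries for a given row and group should reproduce the raw margin prediction for that class, which is presumably the property the commit's simple test checks.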
@@ -400,8 +400,11 @@ class LearnerImpl : public Learner {
                bool output_margin,
                std::vector<bst_float> *out_preds,
                unsigned ntree_limit,
-               bool pred_leaf) const override {
-    if (pred_leaf) {
+               bool pred_leaf,
+               bool pred_contribs) const override {
+    if (pred_contribs) {
+      gbm_->PredictContribution(data, out_preds, ntree_limit);
+    } else if (pred_leaf) {
       gbm_->PredictLeaf(data, out_preds, ntree_limit);
     } else {
       this->PredictRaw(data, out_preds, ntree_limit);