Add prediction of feature contributions (#2003)
* Add prediction of feature contributions This implements the idea described at http://blog.datadive.net/interpreting-random-forests/ which tries to give insight in how a prediction is composed of its feature contributions and a bias. * Support multi-class models * Calculate learning_rate per-tree instead of using the one from the first tree * Do not rely on node.base_weight * learning_rate having the same value as the node mean value (aka leaf value, if it were a leaf); instead calculate them (lazily) on-the-fly * Add simple test for contributions feature * Check against param.num_nodes instead of checking for non-zero length * Loop over all roots instead of only the first
This commit is contained in:
committed by
Vadim Khotilovich
parent
e62be19c70
commit
6bd1869026
@@ -2,6 +2,7 @@
|
||||
import numpy as np
|
||||
import xgboost as xgb
|
||||
import unittest
|
||||
import itertools
|
||||
import json
|
||||
|
||||
dpath = 'demo/data/'
|
||||
@@ -250,3 +251,26 @@ class TestBasic(unittest.TestCase):
|
||||
cv = xgb.cv(params, dm, num_boost_round=10, shuffle=False, nfold=10, as_pandas=False)
|
||||
assert isinstance(cv, dict)
|
||||
assert len(cv) == (4)
|
||||
|
||||
|
||||
def test_contributions():
|
||||
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
|
||||
dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
|
||||
|
||||
def test_fn(max_depth, num_rounds):
|
||||
# train
|
||||
params = {'max_depth': max_depth, 'eta': 1, 'silent': 1}
|
||||
bst = xgb.train(params, dtrain, num_boost_round=num_rounds)
|
||||
|
||||
# predict
|
||||
preds = bst.predict(dtest)
|
||||
contribs = bst.predict(dtest, pred_contribs=True)
|
||||
|
||||
# result should be (number of features + BIAS) * number of rows
|
||||
assert contribs.shape == (dtest.num_row(), dtest.num_col() + 1)
|
||||
|
||||
# sum of contributions should be same as predictions
|
||||
np.testing.assert_array_almost_equal(np.sum(contribs, axis=1), preds)
|
||||
|
||||
for max_depth, num_rounds in itertools.product(range(0, 3), range(1, 5)):
|
||||
yield test_fn, max_depth, num_rounds
|
||||
|
||||
Reference in New Issue
Block a user