Added other feature importances in python package (#1135)

* added new function to calculate other feature importances

* added capability to plot other feature importance measures

* changed plotting default to fscore

* added info on importance_type to boilerplate comment

* updated text of error statement

* added self module name to fix call

* added unit test for feature importances

* style fixes
Alistair Johnson 2016-05-02 13:25:24 -04:00 committed by Yuan (Terry) Tang
parent c2c61eefd9
commit 6750c8b743
4 changed files with 121 additions and 17 deletions


@@ -1024,20 +1024,87 @@ class Booster(object):
        fmap: str (optional)
           The name of feature map file
        """
        trees = self.get_dump(fmap)
        fmap = {}
        for tree in trees:
            for line in tree.split('\n'):
                arr = line.split('[')
                if len(arr) == 1:
                    continue
                fid = arr[1].split(']')[0]
                fid = fid.split('<')[0]
                if fid not in fmap:
                    fmap[fid] = 1
                else:
                    fmap[fid] += 1
        return fmap
        return self.get_score(fmap, importance_type='weight')

    def get_score(self, fmap='', importance_type='weight'):
        """Get feature importance of each feature.
        Importance type can be defined as:
            'weight' - the number of times a feature is used to split the data across all trees.
            'gain' - the average gain of the feature when it is used in trees.
            'cover' - the average coverage of the feature when it is used in trees.

        Parameters
        ----------
        fmap: str (optional)
           The name of feature map file
        importance_type: str, default 'weight'
            One of 'weight', 'gain' or 'cover' as described above.
        """
        if importance_type not in ['weight', 'gain', 'cover']:
            msg = "importance_type mismatch, got '{}', expected 'weight', 'gain', or 'cover'"
            raise ValueError(msg.format(importance_type))
        # for 'weight', fmap stores the number of times each feature is used to split
        if importance_type == 'weight':
            # do a simpler tree dump to save time
            trees = self.get_dump(fmap, with_stats=False)

            fmap = {}
            for tree in trees:
                for line in tree.split('\n'):
                    # look for the opening square bracket
                    arr = line.split('[')
                    # if no opening bracket (leaf node), ignore this line
                    if len(arr) == 1:
                        continue

                    # extract feature name from string between []
                    fid = arr[1].split(']')[0].split('<')[0]

                    if fid not in fmap:
                        # if the feature hasn't been seen yet
                        fmap[fid] = 1
                    else:
                        fmap[fid] += 1

            return fmap
        else:
            trees = self.get_dump(fmap, with_stats=True)

            importance_type += '='
            fmap = {}
            gmap = {}
            for tree in trees:
                for line in tree.split('\n'):
                    # look for the opening square bracket
                    arr = line.split('[')
                    # if no opening bracket (leaf node), ignore this line
                    if len(arr) == 1:
                        continue

                    # look for the closing bracket, extract only info within that bracket
                    fid = arr[1].split(']')

                    # extract gain or cover from string after the closing bracket
                    g = float(fid[1].split(importance_type)[1].split(',')[0])

                    # extract feature name from string before the closing bracket
                    fid = fid[0].split('<')[0]

                    if fid not in fmap:
                        # if the feature hasn't been seen yet
                        fmap[fid] = 1
                        gmap[fid] = g
                    else:
                        fmap[fid] += 1
                        gmap[fid] += g

            # calculate average value (gain/cover) for each feature
            for fid in gmap:
                gmap[fid] = gmap[fid] / fmap[fid]

            return gmap

    def _validate_features(self, data):
        """


@@ -14,6 +14,7 @@ from .sklearn import XGBModel
def plot_importance(booster, ax=None, height=0.2,
                    xlim=None, ylim=None, title='Feature importance',
                    xlabel='F score', ylabel='Features',
                    importance_type='weight',
                    grid=True, **kwargs):
    """Plot importance based on fitted trees.
@@ -24,6 +25,12 @@ def plot_importance(booster, ax=None, height=0.2,
        Booster or XGBModel instance, or dict taken by Booster.get_fscore()
    ax : matplotlib Axes, default None
        Target axes instance. If None, new figure and axes will be created.
    importance_type : str, default "weight"
        How the importance is calculated: either "weight", "gain", or "cover".
        "weight" is the number of times a feature appears in a tree.
        "gain" is the average gain of splits which use the feature.
        "cover" is the average coverage of splits which use the feature,
        where coverage is defined as the number of samples affected by the split.
    height : float, default 0.2
        Bar height, passed to ax.barh()
    xlim : tuple, default None
@@ -50,16 +57,16 @@ def plot_importance(booster, ax=None, height=0.2,
        raise ImportError('You must install matplotlib to plot importance')

    if isinstance(booster, XGBModel):
        importance = booster.booster().get_fscore()
        importance = booster.booster().get_score(importance_type=importance_type)
    elif isinstance(booster, Booster):
        importance = booster.get_fscore()
        importance = booster.get_score(importance_type=importance_type)
    elif isinstance(booster, dict):
        importance = booster
    else:
        raise ValueError('tree must be Booster, XGBModel or dict instance')

    if len(importance) == 0:
        raise ValueError('Booster.get_fscore() results in empty')
        raise ValueError('Booster.get_score() results in empty')

    tuples = [(k, importance[k]) for k in importance]
    tuples = sorted(tuples, key=lambda x: x[1])
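With the importance_type argument threaded through, the same plot can show split counts, average gain, or average cover. A short sketch, assuming a trained Booster bst and matplotlib installed (the axis label here is illustrative):

    import matplotlib.pyplot as plt
    import xgboost as xgb

    # default behaviour: bar lengths are split counts ('weight')
    xgb.plot_importance(bst)
    # alternative: bar lengths are the average gain of each feature's splits
    xgb.plot_importance(bst, importance_type='gain', xlabel='Average gain')
    plt.show()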


@@ -109,6 +109,7 @@ class XGBModel(XGBModelBase):
        hess: array_like of shape [n_samples]
            The value of the second derivative for each sample point
    """

    def __init__(self, max_depth=3, learning_rate=0.1, n_estimators=100,
                 silent=True, objective="reg:linear",
                 nthread=-1, gamma=0, min_child_weight=1, max_delta_step=0,
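The grad/hess lines above are the tail of the custom-objective contract in the XGBModel docstring: a callable objective must return the per-sample gradient and second derivative. A minimal sketch of such a callable (plain squared error; XGBRegressor and the function name are illustrative, not part of this change):

    import numpy as np
    from xgboost import XGBRegressor

    def squared_error_obj(y_true, y_pred):
        # gradient and second derivative of 0.5 * (y_pred - y_true)**2
        grad = y_pred - y_true
        hess = np.ones_like(y_pred)
        return grad, hess

    model = XGBRegressor(objective=squared_error_obj, n_estimators=10)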


@@ -125,6 +125,35 @@ class TestBasic(unittest.TestCase):
        dm = xgb.DMatrix(dummy, feature_names=list('abcde'))
        self.assertRaises(ValueError, bst.predict, dm)

    def test_feature_importances(self):
        data = np.random.randn(100, 5)
        target = np.array([0, 1] * 50)

        features = ['Feature1', 'Feature2', 'Feature3', 'Feature4', 'Feature5']

        dm = xgb.DMatrix(data, label=target,
                         feature_names=features)
        params = {'objective': 'multi:softprob',
                  'eval_metric': 'mlogloss',
                  'eta': 0.3,
                  'num_class': 3}

        bst = xgb.train(params, dm, num_boost_round=10)

        # number of feature importances should == number of features
        scores1 = bst.get_score()
        scores2 = bst.get_score(importance_type='weight')
        scores3 = bst.get_score(importance_type='cover')
        scores4 = bst.get_score(importance_type='gain')
        assert len(scores1) == len(features)
        assert len(scores2) == len(features)
        assert len(scores3) == len(features)
        assert len(scores4) == len(features)

        # check backwards compatibility of get_fscore
        fscores = bst.get_fscore()
        assert scores1 == fscores

    def test_load_file_invalid(self):
        self.assertRaises(xgb.core.XGBoostError, xgb.Booster,
                          model_file='incorrect_path')