Add total_gain and total_cover importance measures (#3498)

Add `'total_gain'` and `'total_cover'` as possible `importance_type`
arguments to `Booster.get_score` in the Python package.

`get_score` already accepts `'gain'` as an `importance_type`, which returns
each feature's average gain over all of its splits.  `'total_gain'` is
computed from the same per-split gains, but returns their total rather than
their average.  This seems more intuitively meaningful, and also matches the
behavior of the R package's `xgb.importance` function.

I also added an analogous `'total_cover'` importance type for consistency.

This should resolve #3484.
This commit is contained in:
jqmp 2018-07-23 03:30:55 -04:00 committed by Philip Hyunsu Cho
parent a1505de631
commit e9a97e0d88
2 changed files with 26 additions and 7 deletions

View File

@ -1336,17 +1336,23 @@ class Booster(object):
"""Get feature importance of each feature. """Get feature importance of each feature.
Importance type can be defined as: Importance type can be defined as:
'weight' - the number of times a feature is used to split the data across all trees. 'weight' - the number of times a feature is used to split the data across all trees.
'gain' - the average gain of the feature when it is used in trees 'gain' - the average gain across all splits the feature is used in.
'cover' - the average coverage of the feature when it is used in trees 'cover' - the average coverage across all splits the feature is used in.
'total_gain' - the total gain across all splits the feature is used in.
'total_cover' - the total coverage across all splits the feature is used in.
Parameters Parameters
---------- ----------
fmap: str (optional) fmap: str (optional)
The name of feature map file The name of feature map file.
importance_type: str, default 'weight'
One of the importance types defined above.
""" """
if importance_type not in ['weight', 'gain', 'cover']: allowed_importance_types = ['weight', 'gain', 'cover', 'total_gain', 'total_cover']
msg = "importance_type mismatch, got '{}', expected 'weight', 'gain', or 'cover'" if importance_type not in allowed_importance_types:
msg = ("importance_type mismatch, got '{}', expected one of " +
repr(allowed_importance_types))
raise ValueError(msg.format(importance_type)) raise ValueError(msg.format(importance_type))
# if it's weight, then omap stores the number of missing values # if it's weight, then omap stores the number of missing values
@ -1375,6 +1381,14 @@ class Booster(object):
return fmap return fmap
else: else:
average_over_splits = True
if importance_type == 'total_gain':
importance_type = 'gain'
average_over_splits = False
elif importance_type == 'total_cover':
importance_type = 'cover'
average_over_splits = False
trees = self.get_dump(fmap, with_stats=True) trees = self.get_dump(fmap, with_stats=True)
importance_type += '=' importance_type += '='
@ -1406,8 +1420,9 @@ class Booster(object):
gmap[fid] += g gmap[fid] += g
# calculate average value (gain/cover) for each feature # calculate average value (gain/cover) for each feature
for fid in gmap: if average_over_splits:
gmap[fid] = gmap[fid] / fmap[fid] for fid in gmap:
gmap[fid] = gmap[fid] / fmap[fid]
return gmap return gmap

View File

@ -33,10 +33,14 @@ class TestSHAP(unittest.TestCase):
scores2 = bst.get_score(importance_type='weight') scores2 = bst.get_score(importance_type='weight')
scores3 = bst.get_score(importance_type='cover') scores3 = bst.get_score(importance_type='cover')
scores4 = bst.get_score(importance_type='gain') scores4 = bst.get_score(importance_type='gain')
scores5 = bst.get_score(importance_type='total_cover')
scores6 = bst.get_score(importance_type='total_gain')
assert len(scores1) == len(features) assert len(scores1) == len(features)
assert len(scores2) == len(features) assert len(scores2) == len(features)
assert len(scores3) == len(features) assert len(scores3) == len(features)
assert len(scores4) == len(features) assert len(scores4) == len(features)
assert len(scores5) == len(features)
assert len(scores6) == len(features)
# check backwards compatibility of get_fscore # check backwards compatibility of get_fscore
fscores = bst.get_fscore() fscores = bst.get_fscore()