Add total_gain and total_cover importance measures (#3498)
Add `'total_gain'` and `'total_cover'` as possible values of the `importance_type` argument to `Booster.get_score` in the Python package. `get_score` already accepts `importance_type='gain'`, which returns each feature's average gain over all of its splits. `'total_gain'` does the same, but returns the total rather than the average. This seems more intuitively meaningful, and it also matches the behavior of the R package's `xgb.importance` function. I also added an analogous `'total_cover'` option for consistency. This should resolve #3484.
This commit is contained in:
parent
a1505de631
commit
e9a97e0d88
@ -1336,17 +1336,23 @@ class Booster(object):
|
|||||||
"""Get feature importance of each feature.
|
"""Get feature importance of each feature.
|
||||||
Importance type can be defined as:
|
Importance type can be defined as:
|
||||||
'weight' - the number of times a feature is used to split the data across all trees.
|
'weight' - the number of times a feature is used to split the data across all trees.
|
||||||
'gain' - the average gain of the feature when it is used in trees
|
'gain' - the average gain across all splits the feature is used in.
|
||||||
'cover' - the average coverage of the feature when it is used in trees
|
'cover' - the average coverage across all splits the feature is used in.
|
||||||
|
'total_gain' - the total gain across all splits the feature is used in.
|
||||||
|
'total_cover' - the total coverage across all splits the feature is used in.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
fmap: str (optional)
|
fmap: str (optional)
|
||||||
The name of feature map file
|
The name of feature map file.
|
||||||
|
importance_type: str, default 'weight'
|
||||||
|
One of the importance types defined above.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if importance_type not in ['weight', 'gain', 'cover']:
|
allowed_importance_types = ['weight', 'gain', 'cover', 'total_gain', 'total_cover']
|
||||||
msg = "importance_type mismatch, got '{}', expected 'weight', 'gain', or 'cover'"
|
if importance_type not in allowed_importance_types:
|
||||||
|
msg = ("importance_type mismatch, got '{}', expected one of " +
|
||||||
|
repr(allowed_importance_types))
|
||||||
raise ValueError(msg.format(importance_type))
|
raise ValueError(msg.format(importance_type))
|
||||||
|
|
||||||
# if it's weight, then omap stores the number of missing values
|
# if it's weight, then omap stores the number of missing values
|
||||||
@ -1375,6 +1381,14 @@ class Booster(object):
|
|||||||
return fmap
|
return fmap
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
average_over_splits = True
|
||||||
|
if importance_type == 'total_gain':
|
||||||
|
importance_type = 'gain'
|
||||||
|
average_over_splits = False
|
||||||
|
elif importance_type == 'total_cover':
|
||||||
|
importance_type = 'cover'
|
||||||
|
average_over_splits = False
|
||||||
|
|
||||||
trees = self.get_dump(fmap, with_stats=True)
|
trees = self.get_dump(fmap, with_stats=True)
|
||||||
|
|
||||||
importance_type += '='
|
importance_type += '='
|
||||||
@ -1406,8 +1420,9 @@ class Booster(object):
|
|||||||
gmap[fid] += g
|
gmap[fid] += g
|
||||||
|
|
||||||
# calculate average value (gain/cover) for each feature
|
# calculate average value (gain/cover) for each feature
|
||||||
for fid in gmap:
|
if average_over_splits:
|
||||||
gmap[fid] = gmap[fid] / fmap[fid]
|
for fid in gmap:
|
||||||
|
gmap[fid] = gmap[fid] / fmap[fid]
|
||||||
|
|
||||||
return gmap
|
return gmap
|
||||||
|
|
||||||
|
|||||||
@ -33,10 +33,14 @@ class TestSHAP(unittest.TestCase):
|
|||||||
scores2 = bst.get_score(importance_type='weight')
|
scores2 = bst.get_score(importance_type='weight')
|
||||||
scores3 = bst.get_score(importance_type='cover')
|
scores3 = bst.get_score(importance_type='cover')
|
||||||
scores4 = bst.get_score(importance_type='gain')
|
scores4 = bst.get_score(importance_type='gain')
|
||||||
|
scores5 = bst.get_score(importance_type='total_cover')
|
||||||
|
scores6 = bst.get_score(importance_type='total_gain')
|
||||||
assert len(scores1) == len(features)
|
assert len(scores1) == len(features)
|
||||||
assert len(scores2) == len(features)
|
assert len(scores2) == len(features)
|
||||||
assert len(scores3) == len(features)
|
assert len(scores3) == len(features)
|
||||||
assert len(scores4) == len(features)
|
assert len(scores4) == len(features)
|
||||||
|
assert len(scores5) == len(features)
|
||||||
|
assert len(scores6) == len(features)
|
||||||
|
|
||||||
# check backwards compatibility of get_fscore
|
# check backwards compatibility of get_fscore
|
||||||
fscores = bst.get_fscore()
|
fscores = bst.get_fscore()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user