diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index f99df9216..8e282cae9 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -1336,17 +1336,23 @@ class Booster(object): """Get feature importance of each feature. Importance type can be defined as: 'weight' - the number of times a feature is used to split the data across all trees. - 'gain' - the average gain of the feature when it is used in trees - 'cover' - the average coverage of the feature when it is used in trees + 'gain' - the average gain across all splits the feature is used in. + 'cover' - the average coverage across all splits the feature is used in. + 'total_gain' - the total gain across all splits the feature is used in. + 'total_cover' - the total coverage across all splits the feature is used in. Parameters ---------- fmap: str (optional) - The name of feature map file + The name of feature map file. + importance_type: str, default 'weight' + One of the importance types defined above. """ - if importance_type not in ['weight', 'gain', 'cover']: - msg = "importance_type mismatch, got '{}', expected 'weight', 'gain', or 'cover'" + allowed_importance_types = ['weight', 'gain', 'cover', 'total_gain', 'total_cover'] + if importance_type not in allowed_importance_types: + msg = ("importance_type mismatch, got '{}', expected one of " + + repr(allowed_importance_types)) raise ValueError(msg.format(importance_type)) # if it's weight, then omap stores the number of missing values @@ -1375,6 +1381,14 @@ class Booster(object): return fmap else: + average_over_splits = True + if importance_type == 'total_gain': + importance_type = 'gain' + average_over_splits = False + elif importance_type == 'total_cover': + importance_type = 'cover' + average_over_splits = False + trees = self.get_dump(fmap, with_stats=True) importance_type += '=' @@ -1406,8 +1420,9 @@ class Booster(object): gmap[fid] += g # calculate average value (gain/cover) for each feature - for fid in gmap: - gmap[fid] = gmap[fid] / fmap[fid] + if average_over_splits: + for fid in gmap: + gmap[fid] = gmap[fid] / fmap[fid] return gmap diff --git a/tests/python/test_shap.py b/tests/python/test_shap.py index e4256dcc8..3afe1630f 100644 --- a/tests/python/test_shap.py +++ b/tests/python/test_shap.py @@ -33,10 +33,14 @@ class TestSHAP(unittest.TestCase): scores2 = bst.get_score(importance_type='weight') scores3 = bst.get_score(importance_type='cover') scores4 = bst.get_score(importance_type='gain') + scores5 = bst.get_score(importance_type='total_cover') + scores6 = bst.get_score(importance_type='total_gain') assert len(scores1) == len(features) assert len(scores2) == len(features) assert len(scores3) == len(features) assert len(scores4) == len(features) + assert len(scores5) == len(features) + assert len(scores6) == len(features) # check backwards compatibility of get_fscore fscores = bst.get_fscore()