diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py
index eb71e287b..a33fe447b 100644
--- a/python-package/xgboost/core.py
+++ b/python-package/xgboost/core.py
@@ -1024,20 +1024,87 @@ class Booster(object):
         fmap: str (optional)
            The name of feature map file
         """
-        trees = self.get_dump(fmap)
-        fmap = {}
-        for tree in trees:
-            for line in tree.split('\n'):
-                arr = line.split('[')
-                if len(arr) == 1:
-                    continue
-                fid = arr[1].split(']')[0]
-                fid = fid.split('<')[0]
-                if fid not in fmap:
-                    fmap[fid] = 1
-                else:
-                    fmap[fid] += 1
-        return fmap
+
+        return self.get_score(fmap, importance_type='weight')
+
+    def get_score(self, fmap='', importance_type='weight'):
+        """Get feature importance of each feature.
+        Importance type can be defined as:
+            'weight' - the number of times a feature is used to split the data across all trees.
+            'gain' - the average gain of the feature when it is used in trees.
+            'cover' - the average coverage of the feature when it is used in trees.
+
+        Parameters
+        ----------
+        fmap: str (optional)
+           The name of feature map file
+        """
+
+        if importance_type not in ['weight', 'gain', 'cover']:
+            msg = "importance_type mismatch, got '{}', expected 'weight', 'gain', or 'cover'"
+            raise ValueError(msg.format(importance_type))
+
+        # for 'weight', fmap counts how many times each feature is used in a split
+        if importance_type == 'weight':
+            # do a simpler tree dump to save time
+            trees = self.get_dump(fmap, with_stats=False)
+
+            fmap = {}
+            for tree in trees:
+                for line in tree.split('\n'):
+                    # look for the opening square bracket
+                    arr = line.split('[')
+                    # if no opening bracket (leaf node), ignore this line
+                    if len(arr) == 1:
+                        continue
+
+                    # extract feature name from string between []
+                    fid = arr[1].split(']')[0].split('<')[0]
+
+                    if fid not in fmap:
+                        # if the feature hasn't been seen yet
+                        fmap[fid] = 1
+                    else:
+                        fmap[fid] += 1
+
+            return fmap
+
+        else:
+            trees = self.get_dump(fmap, with_stats=True)
+
+            importance_type += '='
+            fmap = {}
+            gmap = {}
+            for tree in trees:
+                for line in tree.split('\n'):
+                    # look for the opening square bracket
+                    arr = line.split('[')
+                    # if no opening bracket (leaf node), ignore this line
+                    if len(arr) == 1:
+                        continue
+
+                    # look for the closing bracket, extract only info within that bracket
+                    fid = arr[1].split(']')
+
+                    # extract gain or cover from string after the closing bracket
+                    g = float(fid[1].split(importance_type)[1].split(',')[0])
+
+                    # extract feature name from string before the closing bracket
+                    fid = fid[0].split('<')[0]
+
+                    if fid not in fmap:
+                        # if the feature hasn't been seen yet
+                        fmap[fid] = 1
+                        gmap[fid] = g
+                    else:
+                        fmap[fid] += 1
+                        gmap[fid] += g
+
+            # calculate average value (gain/cover) for each feature
+            for fid in gmap:
+                gmap[fid] = gmap[fid] / fmap[fid]
+
+            return gmap
 
     def _validate_features(self, data):
         """
diff --git a/python-package/xgboost/plotting.py b/python-package/xgboost/plotting.py
index 0a70799ad..94982b088 100644
--- a/python-package/xgboost/plotting.py
+++ b/python-package/xgboost/plotting.py
@@ -14,6 +14,7 @@ from .sklearn import XGBModel
 def plot_importance(booster, ax=None, height=0.2,
                     xlim=None, ylim=None, title='Feature importance',
                     xlabel='F score', ylabel='Features',
+                    importance_type='weight',
                     grid=True, **kwargs):
     """Plot importance based on fitted trees.
 
@@ -24,6 +25,12 @@ def plot_importance(booster, ax=None, height=0.2,
         Booster or XGBModel instance, or dict taken by Booster.get_fscore()
     ax : matplotlib Axes, default None
         Target axes instance. If None, new figure and axes will be created.
+    importance_type : str, default "weight"
+        How the importance is calculated: either "weight", "gain", or "cover"
+        "weight" is the number of times a feature appears in a tree
+        "gain" is the average gain of splits which use the feature
+        "cover" is the average coverage of splits which use the feature
+            where coverage is defined as the number of samples affected by the split
     height : float, default 0.2
         Bar height, passed to ax.barh()
     xlim : tuple, default None
@@ -50,16 +57,16 @@ def plot_importance(booster, ax=None, height=0.2,
         raise ImportError('You must install matplotlib to plot importance')
 
     if isinstance(booster, XGBModel):
-        importance = booster.booster().get_fscore()
+        importance = booster.booster().get_score(importance_type=importance_type)
     elif isinstance(booster, Booster):
-        importance = booster.get_fscore()
+        importance = booster.get_score(importance_type=importance_type)
     elif isinstance(booster, dict):
         importance = booster
     else:
         raise ValueError('tree must be Booster, XGBModel or dict instance')
 
     if len(importance) == 0:
-        raise ValueError('Booster.get_fscore() results in empty')
+        raise ValueError('Booster.get_score() results in empty')
 
     tuples = [(k, importance[k]) for k in importance]
     tuples = sorted(tuples, key=lambda x: x[1])
diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py
index 141a79182..2fe4e2ee3 100644
--- a/python-package/xgboost/sklearn.py
+++ b/python-package/xgboost/sklearn.py
@@ -109,6 +109,7 @@ class XGBModel(XGBModelBase):
         hess: array_like of shape [n_samples]
             The value of the second derivative for each sample point
     """
+
     def __init__(self, max_depth=3, learning_rate=0.1, n_estimators=100,
                  silent=True, objective="reg:linear",
                  nthread=-1, gamma=0, min_child_weight=1, max_delta_step=0,
diff --git a/tests/python/test_basic.py b/tests/python/test_basic.py
index 0ac401c3a..386462091 100644
--- a/tests/python/test_basic.py
+++ b/tests/python/test_basic.py
@@ -125,6 +125,35 @@ class TestBasic(unittest.TestCase):
         dm = xgb.DMatrix(dummy, feature_names=list('abcde'))
         self.assertRaises(ValueError, bst.predict, dm)
 
+    def test_feature_importances(self):
+        data = np.random.randn(100, 5)
+        target = np.array([0, 1] * 50)
+
+        features = ['Feature1', 'Feature2', 'Feature3', 'Feature4', 'Feature5']
+
+        dm = xgb.DMatrix(data, label=target,
+                         feature_names=features)
+        params = {'objective': 'multi:softprob',
+                  'eval_metric': 'mlogloss',
+                  'eta': 0.3,
+                  'num_class': 3}
+
+        bst = xgb.train(params, dm, num_boost_round=10)
+
+        # number of feature importances should == number of features
+        scores1 = bst.get_score()
+        scores2 = bst.get_score(importance_type='weight')
+        scores3 = bst.get_score(importance_type='cover')
+        scores4 = bst.get_score(importance_type='gain')
+        assert len(scores1) == len(features)
+        assert len(scores2) == len(features)
+        assert len(scores3) == len(features)
+        assert len(scores4) == len(features)
+
+        # check backwards compatibility of get_fscore
+        fscores = bst.get_fscore()
+        assert scores1 == fscores
+
     def test_load_file_invalid(self):
         self.assertRaises(xgb.core.XGBoostError,
                           xgb.Booster, model_file='incorrect_path')
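Below is a minimal usage sketch of the get_score API and the plot_importance keyword introduced by this patch; the toy data, feature names, and training parameters are illustrative assumptions, not taken from the diff.

import numpy as np
import xgboost as xgb

# Illustrative toy data; any DMatrix with named features behaves the same way.
data = np.random.randn(100, 5)
label = np.random.randint(2, size=100)
dtrain = xgb.DMatrix(data, label=label,
                     feature_names=['f0', 'f1', 'f2', 'f3', 'f4'])

bst = xgb.train({'objective': 'binary:logistic'}, dtrain, num_boost_round=10)

# 'weight' counts splits per feature; 'gain' and 'cover' average the
# per-split statistics parsed from the stats dump.
print(bst.get_score(importance_type='weight'))
print(bst.get_score(importance_type='gain'))
print(bst.get_score(importance_type='cover'))

# get_fscore() now delegates to get_score(..., importance_type='weight'),
# so existing callers keep getting the same counts.
assert bst.get_fscore() == bst.get_score(importance_type='weight')

# plot_importance forwards the same keyword (requires matplotlib).
# ax = xgb.plot_importance(bst, importance_type='gain')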