Fix divide by 0 in feature importance when no split is found. (#6676)
This commit is contained in:
parent
72892cc80d
commit
a4101de678
@ -920,7 +920,10 @@ class XGBModel(XGBModelBase):
|
||||
score = b.get_score(importance_type=self.importance_type)
|
||||
all_features = [score.get(f, 0.) for f in b.feature_names]
|
||||
all_features = np.array(all_features, dtype=np.float32)
|
||||
return all_features / all_features.sum()
|
||||
total = all_features.sum()
|
||||
if total == 0:
|
||||
return all_features
|
||||
return all_features / total
|
||||
|
||||
@property
|
||||
def coef_(self):
|
||||
|
||||
@ -252,7 +252,9 @@ def test_feature_importances_gain():
|
||||
xgb_model = xgb.XGBClassifier(
|
||||
random_state=0, tree_method="exact",
|
||||
learning_rate=0.1,
|
||||
importance_type="gain").fit(X, y)
|
||||
importance_type="gain",
|
||||
use_label_encoder=False,
|
||||
).fit(X, y)
|
||||
|
||||
exp = np.array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
|
||||
0.00326159, 0., 0., 0., 0., 0., 0., 0., 0.,
|
||||
@ -270,17 +272,30 @@ def test_feature_importances_gain():
|
||||
y = pd.Series(digits['target'])
|
||||
X = pd.DataFrame(digits['data'])
|
||||
xgb_model = xgb.XGBClassifier(
|
||||
random_state=0, tree_method="exact",
|
||||
random_state=0,
|
||||
tree_method="exact",
|
||||
learning_rate=0.1,
|
||||
importance_type="gain").fit(X, y)
|
||||
importance_type="gain",
|
||||
use_label_encoder=False,
|
||||
).fit(X, y)
|
||||
np.testing.assert_almost_equal(xgb_model.feature_importances_, exp)
|
||||
|
||||
xgb_model = xgb.XGBClassifier(
|
||||
random_state=0, tree_method="exact",
|
||||
random_state=0,
|
||||
tree_method="exact",
|
||||
learning_rate=0.1,
|
||||
importance_type="gain").fit(X, y)
|
||||
importance_type="gain",
|
||||
use_label_encoder=False,
|
||||
).fit(X, y)
|
||||
np.testing.assert_almost_equal(xgb_model.feature_importances_, exp)
|
||||
|
||||
# no split can be found
|
||||
cls = xgb.XGBClassifier(
|
||||
min_child_weight=1000, tree_method="hist", n_estimators=1, use_label_encoder=False
|
||||
)
|
||||
cls.fit(X, y)
|
||||
assert np.all(cls.feature_importances_ == 0)
|
||||
|
||||
|
||||
def test_select_feature():
|
||||
from sklearn.datasets import load_digits
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user