[multi] Implement weight feature importance. (#10700)
commit 9b88495840 (parent 402e7837fb)
@@ -236,12 +236,11 @@ class GBTree : public GradientBooster {
     auto add_score = [&](auto fn) {
       for (auto idx : trees) {
         CHECK_LE(idx, total_n_trees) << "Invalid tree index.";
-        auto const& p_tree = model_.trees[idx];
-        p_tree->WalkTree([&](bst_node_t nidx) {
-          auto const& node = (*p_tree)[nidx];
-          if (!node.IsLeaf()) {
-            split_counts[node.SplitIndex()]++;
-            fn(p_tree, nidx, node.SplitIndex());
+        auto const& tree = *model_.trees[idx];
+        tree.WalkTree([&](bst_node_t nidx) {
+          if (!tree.IsLeaf(nidx)) {
+            split_counts[tree.SplitIndex(nidx)]++;
+            fn(tree, nidx, tree.SplitIndex(nidx));
           }
           return true;
         });
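The walk now goes through the tree-level accessors (`tree.IsLeaf(nidx)`, `tree.SplitIndex(nidx)`) instead of indexing into per-node objects, so the same split-counting loop covers both scalar-leaf and vector-leaf (multi-target) trees. For reference, "weight" importance is just the number of times each feature appears as a split across all trees; below is a minimal Python sketch of the same count over the JSON dump of a single-target model (the toy data and the `count_splits` helper are illustrative, not part of this commit):

```python
import json

import numpy as np
import xgboost as xgb

# Toy single-target regression model.
rng = np.random.default_rng(0)
X = rng.normal(size=(256, 4))
y = 2.0 * X[:, 0] + X[:, 1]
booster = xgb.train({"tree_method": "hist"}, xgb.DMatrix(X, label=y), num_boost_round=4)


def count_splits(node: dict, counts: dict) -> None:
    # Internal nodes in the JSON dump carry a "split" field; leaves carry "leaf".
    if "split" in node:
        counts[node["split"]] = counts.get(node["split"], 0) + 1
        for child in node["children"]:
            count_splits(child, counts)


counts: dict = {}
for dump in booster.get_dump(dump_format="json"):
    count_splits(json.loads(dump), counts)

# The manual walk should agree with the built-in "weight" importance.
for name, n in booster.get_score(importance_type="weight").items():
    assert counts[name] == n
```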
@@ -253,12 +252,18 @@ class GBTree : public GradientBooster {
       gain_map[split] = split_counts[split];
     });
   } else if (importance_type == "gain" || importance_type == "total_gain") {
-    add_score([&](auto const &p_tree, bst_node_t nidx, bst_feature_t split) {
-      gain_map[split] += p_tree->Stat(nidx).loss_chg;
+    if (!model_.trees.empty() && model_.trees.front()->IsMultiTarget()) {
+      LOG(FATAL) << "gain/total_gain " << MTNotImplemented();
+    }
+    add_score([&](auto const& tree, bst_node_t nidx, bst_feature_t split) {
+      gain_map[split] += tree.Stat(nidx).loss_chg;
     });
   } else if (importance_type == "cover" || importance_type == "total_cover") {
-    add_score([&](auto const &p_tree, bst_node_t nidx, bst_feature_t split) {
-      gain_map[split] += p_tree->Stat(nidx).sum_hess;
+    if (!model_.trees.empty() && model_.trees.front()->IsMultiTarget()) {
+      LOG(FATAL) << "cover/total_cover " << MTNotImplemented();
+    }
+    add_score([&](auto const& tree, bst_node_t nidx, bst_feature_t split) {
+      gain_map[split] += tree.Stat(nidx).sum_hess;
     });
   } else {
     LOG(FATAL)
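The gain and cover variants read per-node `loss_chg` / `sum_hess` statistics that are not yet available for multi-target trees, so those branches now fail fast with `MTNotImplemented()`, while "weight", being pure split counting, works. A hedged sketch of the resulting Python-level behavior with the native API (synthetic data; the exact error wording comes from the `LOG(FATAL)` guards above):

```python
import numpy as np
import xgboost as xgb

rng = np.random.default_rng(0)
X = rng.normal(size=(128, 8))
y = rng.normal(size=(128, 3))  # three targets -> one vector-leaf tree per iteration

booster = xgb.train(
    {"multi_strategy": "multi_output_tree", "tree_method": "hist"},
    xgb.DMatrix(X, label=y),
    num_boost_round=4,
)

# Split counting is well defined for vector-leaf trees, so this now works.
print(booster.get_score(importance_type="weight"))

# gain/total_gain and cover/total_cover are still rejected for multi-target trees.
try:
    booster.get_score(importance_type="gain")
except xgb.core.XGBoostError as err:
    print(err)  # message mentions "gain/total_gain", per the guard above
```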
@@ -336,6 +336,36 @@ def test_feature_importances_weight():
     cls.feature_importances_
+
+
+def test_feature_importances_weight_vector_leaf() -> None:
+    from sklearn.datasets import make_multilabel_classification
+
+    X, y = make_multilabel_classification(random_state=1994)
+    with pytest.raises(ValueError, match="gain/total_gain"):
+        clf = xgb.XGBClassifier(multi_strategy="multi_output_tree")
+        clf.fit(X, y)
+        clf.feature_importances_
+
+    with pytest.raises(ValueError, match="cover/total_cover"):
+        clf = xgb.XGBClassifier(
+            multi_strategy="multi_output_tree", importance_type="cover"
+        )
+        clf.fit(X, y)
+        clf.feature_importances_
+
+    clf = xgb.XGBClassifier(
+        multi_strategy="multi_output_tree",
+        importance_type="weight",
+        colsample_bynode=0.2,
+    )
+    clf.fit(X, y, feature_weights=np.arange(0, X.shape[1]))
+    fi = clf.feature_importances_
+    assert fi[0] == 0.0
+    assert fi[-1] > fi[1] * 5
+
+    w = np.polynomial.Polynomial.fit(np.arange(0, X.shape[1]), fi, deg=1)
+    assert w.coef[1] > 0.03
 
 
 @pytest.mark.skipif(**tm.no_pandas())
 def test_feature_importances_gain():
     from sklearn.datasets import load_digits
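The final block of the test leans on `feature_weights`: `colsample_bynode` samples feature i in proportion to its weight, so feature 0 (weight 0) is never picked and its importance is exactly zero, and the fitted degree-1 polynomial checks that importance grows roughly linearly with the weight. A small sketch relating the sklearn attribute back to the raw split counts, assuming the wrapper normalizes them to sum to one (names here are illustrative):

```python
import numpy as np
import xgboost as xgb
from sklearn.datasets import make_multilabel_classification

X, y = make_multilabel_classification(random_state=1994)
clf = xgb.XGBClassifier(multi_strategy="multi_output_tree", importance_type="weight")
clf.fit(X, y)

# Assumption: feature_importances_ is get_score(importance_type="weight")
# normalized so the entries sum to one; unused features default to 0.
raw = clf.get_booster().get_score(importance_type="weight")
counts = np.array([raw.get(f"f{i}", 0.0) for i in range(X.shape[1])])
np.testing.assert_allclose(clf.feature_importances_, counts / counts.sum(), rtol=1e-5)
```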