[multi] Implement weight feature importance. (#10700)
This commit is contained in:
@@ -236,12 +236,11 @@ class GBTree : public GradientBooster {
|
||||
auto add_score = [&](auto fn) {
|
||||
for (auto idx : trees) {
|
||||
CHECK_LE(idx, total_n_trees) << "Invalid tree index.";
|
||||
auto const& p_tree = model_.trees[idx];
|
||||
p_tree->WalkTree([&](bst_node_t nidx) {
|
||||
auto const& node = (*p_tree)[nidx];
|
||||
if (!node.IsLeaf()) {
|
||||
split_counts[node.SplitIndex()]++;
|
||||
fn(p_tree, nidx, node.SplitIndex());
|
||||
auto const& tree = *model_.trees[idx];
|
||||
tree.WalkTree([&](bst_node_t nidx) {
|
||||
if (!tree.IsLeaf(nidx)) {
|
||||
split_counts[tree.SplitIndex(nidx)]++;
|
||||
fn(tree, nidx, tree.SplitIndex(nidx));
|
||||
}
|
||||
return true;
|
||||
});
|
||||
@@ -253,12 +252,18 @@ class GBTree : public GradientBooster {
|
||||
gain_map[split] = split_counts[split];
|
||||
});
|
||||
} else if (importance_type == "gain" || importance_type == "total_gain") {
|
||||
add_score([&](auto const &p_tree, bst_node_t nidx, bst_feature_t split) {
|
||||
gain_map[split] += p_tree->Stat(nidx).loss_chg;
|
||||
if (!model_.trees.empty() && model_.trees.front()->IsMultiTarget()) {
|
||||
LOG(FATAL) << "gain/total_gain " << MTNotImplemented();
|
||||
}
|
||||
add_score([&](auto const& tree, bst_node_t nidx, bst_feature_t split) {
|
||||
gain_map[split] += tree.Stat(nidx).loss_chg;
|
||||
});
|
||||
} else if (importance_type == "cover" || importance_type == "total_cover") {
|
||||
add_score([&](auto const &p_tree, bst_node_t nidx, bst_feature_t split) {
|
||||
gain_map[split] += p_tree->Stat(nidx).sum_hess;
|
||||
if (!model_.trees.empty() && model_.trees.front()->IsMultiTarget()) {
|
||||
LOG(FATAL) << "cover/total_cover " << MTNotImplemented();
|
||||
}
|
||||
add_score([&](auto const& tree, bst_node_t nidx, bst_feature_t split) {
|
||||
gain_map[split] += tree.Stat(nidx).sum_hess;
|
||||
});
|
||||
} else {
|
||||
LOG(FATAL)
|
||||
|
||||
Reference in New Issue
Block a user