Requires setting leaf stat when expanding tree. (#5501)

* Fix GPU Hist feature importance.
This commit is contained in:
Jiaming Yuan
2020-04-10 12:27:03 +08:00
committed by GitHub
parent dc2950fd90
commit 7d52c0b8c2
11 changed files with 179 additions and 50 deletions

View File

@@ -88,13 +88,13 @@ TEST(Tree, Load) {
TEST(Tree, AllocateNode) {
RegTree tree;
tree.ExpandNode(
0, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
tree.ExpandNode(0, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
/*left_sum=*/0.0f, /*right_sum=*/0.0f);
tree.CollapseToLeaf(0, 0);
ASSERT_EQ(tree.NumExtraNodes(), 0);
tree.ExpandNode(
0, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
tree.ExpandNode(0, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
/*left_sum=*/0.0f, /*right_sum=*/0.0f);
ASSERT_EQ(tree.NumExtraNodes(), 2);
auto& nodes = tree.GetNodes();
@@ -107,18 +107,18 @@ RegTree ConstructTree() {
RegTree tree;
tree.ExpandNode(
/*nid=*/0, /*split_index=*/0, /*split_value=*/0.0f,
/*default_left=*/true,
0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
/*default_left=*/true, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, /*left_sum=*/0.0f,
/*right_sum=*/0.0f);
auto left = tree[0].LeftChild();
auto right = tree[0].RightChild();
tree.ExpandNode(
/*nid=*/left, /*split_index=*/1, /*split_value=*/1.0f,
/*default_left=*/false,
0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
/*default_left=*/false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, /*left_sum=*/0.0f,
/*right_sum=*/0.0f);
tree.ExpandNode(
/*nid=*/right, /*split_index=*/2, /*split_value=*/2.0f,
/*default_left=*/false,
0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
/*default_left=*/false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, /*left_sum=*/0.0f,
/*right_sum=*/0.0f);
return tree;
}
@@ -222,7 +222,8 @@ TEST(Tree, DumpDot) {
TEST(Tree, JsonIO) {
RegTree tree;
tree.ExpandNode(0, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
tree.ExpandNode(0, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
/*left_sum=*/0.0f, /*right_sum=*/0.0f);
Json j_tree{Object()};
tree.SaveModel(&j_tree);
@@ -246,8 +247,10 @@ TEST(Tree, JsonIO) {
auto left = tree[0].LeftChild();
auto right = tree[0].RightChild();
tree.ExpandNode(left, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
tree.ExpandNode(right, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
tree.ExpandNode(left, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
/*left_sum=*/0.0f, /*right_sum=*/0.0f);
tree.ExpandNode(right, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
/*left_sum=*/0.0f, /*right_sum=*/0.0f);
tree.SaveModel(&j_tree);
tree.ChangeToLeaf(1, 1.0f);