Requires setting leaf stat when expanding tree. (#5501)
* Fix GPU Hist feature importance.
This commit is contained in:
@@ -88,13 +88,13 @@ TEST(Tree, Load) {
|
||||
|
||||
TEST(Tree, AllocateNode) {
|
||||
RegTree tree;
|
||||
tree.ExpandNode(
|
||||
0, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
|
||||
tree.ExpandNode(0, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
|
||||
/*left_sum=*/0.0f, /*right_sum=*/0.0f);
|
||||
tree.CollapseToLeaf(0, 0);
|
||||
ASSERT_EQ(tree.NumExtraNodes(), 0);
|
||||
|
||||
tree.ExpandNode(
|
||||
0, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
|
||||
tree.ExpandNode(0, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
|
||||
/*left_sum=*/0.0f, /*right_sum=*/0.0f);
|
||||
ASSERT_EQ(tree.NumExtraNodes(), 2);
|
||||
|
||||
auto& nodes = tree.GetNodes();
|
||||
@@ -107,18 +107,18 @@ RegTree ConstructTree() {
|
||||
RegTree tree;
|
||||
tree.ExpandNode(
|
||||
/*nid=*/0, /*split_index=*/0, /*split_value=*/0.0f,
|
||||
/*default_left=*/true,
|
||||
0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
|
||||
/*default_left=*/true, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, /*left_sum=*/0.0f,
|
||||
/*right_sum=*/0.0f);
|
||||
auto left = tree[0].LeftChild();
|
||||
auto right = tree[0].RightChild();
|
||||
tree.ExpandNode(
|
||||
/*nid=*/left, /*split_index=*/1, /*split_value=*/1.0f,
|
||||
/*default_left=*/false,
|
||||
0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
|
||||
/*default_left=*/false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, /*left_sum=*/0.0f,
|
||||
/*right_sum=*/0.0f);
|
||||
tree.ExpandNode(
|
||||
/*nid=*/right, /*split_index=*/2, /*split_value=*/2.0f,
|
||||
/*default_left=*/false,
|
||||
0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
|
||||
/*default_left=*/false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, /*left_sum=*/0.0f,
|
||||
/*right_sum=*/0.0f);
|
||||
return tree;
|
||||
}
|
||||
|
||||
@@ -222,7 +222,8 @@ TEST(Tree, DumpDot) {
|
||||
|
||||
TEST(Tree, JsonIO) {
|
||||
RegTree tree;
|
||||
tree.ExpandNode(0, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
|
||||
tree.ExpandNode(0, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
|
||||
/*left_sum=*/0.0f, /*right_sum=*/0.0f);
|
||||
Json j_tree{Object()};
|
||||
tree.SaveModel(&j_tree);
|
||||
|
||||
@@ -246,8 +247,10 @@ TEST(Tree, JsonIO) {
|
||||
|
||||
auto left = tree[0].LeftChild();
|
||||
auto right = tree[0].RightChild();
|
||||
tree.ExpandNode(left, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
|
||||
tree.ExpandNode(right, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
|
||||
tree.ExpandNode(left, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
|
||||
/*left_sum=*/0.0f, /*right_sum=*/0.0f);
|
||||
tree.ExpandNode(right, 0, 0.0f, false, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
|
||||
/*left_sum=*/0.0f, /*right_sum=*/0.0f);
|
||||
tree.SaveModel(&j_tree);
|
||||
|
||||
tree.ChangeToLeaf(1, 1.0f);
|
||||
|
||||
Reference in New Issue
Block a user