Require leaf statistics when expanding tree (#4015)
* Cache left and right gradient sums * Require leaf statistics when expanding tree
This commit is contained in:
committed by
Philip Hyunsu Cho
parent
0f8af85f64
commit
1fc37e4749
@@ -82,12 +82,15 @@ TEST(Param, SplitEntry) {
|
||||
|
||||
xgboost::tree::SplitEntry se2;
|
||||
EXPECT_FALSE(se1.Update(se2));
|
||||
EXPECT_FALSE(se2.Update(-1, 100, 0, true));
|
||||
ASSERT_TRUE(se2.Update(1, 100, 0, true));
|
||||
EXPECT_FALSE(se2.Update(-1, 100, 0, true, xgboost::tree::GradStats(),
|
||||
xgboost::tree::GradStats()));
|
||||
ASSERT_TRUE(se2.Update(1, 100, 0, true, xgboost::tree::GradStats(),
|
||||
xgboost::tree::GradStats()));
|
||||
ASSERT_TRUE(se1.Update(se2));
|
||||
|
||||
xgboost::tree::SplitEntry se3;
|
||||
se3.Update(2, 101, 0, false);
|
||||
se3.Update(2, 101, 0, false, xgboost::tree::GradStats(),
|
||||
xgboost::tree::GradStats());
|
||||
xgboost::tree::SplitEntry::Reduce(se2, se3);
|
||||
EXPECT_EQ(se2.SplitIndex(), 101);
|
||||
EXPECT_FALSE(se2.DefaultLeft());
|
||||
|
||||
@@ -38,22 +38,13 @@ TEST(Updater, Prune) {
|
||||
pruner->Init(cfg);
|
||||
|
||||
// loss_chg < min_split_loss;
|
||||
tree.ExpandNode(0, 0, 0, true);
|
||||
int cleft = tree[0].LeftChild();
|
||||
int cright = tree[0].RightChild();
|
||||
tree[cleft].SetLeaf(0.3f, 0);
|
||||
tree[cright].SetLeaf(0.4f, 0);
|
||||
tree.ExpandNode(0, 0, 0, true, 0.0f, 0.3f, 0.4f, 0.0f, 0.0f);
|
||||
pruner->Update(&gpair, dmat->get(), trees);
|
||||
|
||||
ASSERT_EQ(tree.NumExtraNodes(), 0);
|
||||
|
||||
// loss_chg > min_split_loss;
|
||||
tree.ExpandNode(0, 0, 0, true);
|
||||
cleft = tree[0].LeftChild();
|
||||
cright = tree[0].RightChild();
|
||||
tree[cleft].SetLeaf(0.3f, 0);
|
||||
tree[cright].SetLeaf(0.4f, 0);
|
||||
tree.Stat(0).loss_chg = 11;
|
||||
tree.ExpandNode(0, 0, 0, true, 0.0f, 0.3f, 0.4f, 11.0f, 0.0f);
|
||||
pruner->Update(&gpair, dmat->get(), trees);
|
||||
|
||||
ASSERT_EQ(tree.NumExtraNodes(), 2);
|
||||
|
||||
@@ -29,12 +29,9 @@ TEST(Updater, Refresh) {
|
||||
std::vector<RegTree*> trees {&tree};
|
||||
std::unique_ptr<TreeUpdater> refresher(TreeUpdater::Create("refresh"));
|
||||
|
||||
tree.ExpandNode(0, 0, 0, true);
|
||||
tree.ExpandNode(0, 2, 0.2f, false, 0.0, 0.2f, 0.8f, 0.0f, 0.0f);
|
||||
int cleft = tree[0].LeftChild();
|
||||
int cright = tree[0].RightChild();
|
||||
tree[cleft].SetLeaf(0.2f, 0);
|
||||
tree[cright].SetLeaf(0.8f, 0);
|
||||
tree[0].SetSplit(2, 0.2f);
|
||||
|
||||
tree.Stat(cleft).base_weight = 1.2;
|
||||
tree.Stat(cright).base_weight = 1.3;
|
||||
|
||||
Reference in New Issue
Block a user