/*!
 * Copyright 2018 by Contributors
 */
#include "../helpers.h"
#include "../../../src/common/host_device_vector.h"
#include <xgboost/tree_updater.h>
#include <gtest/gtest.h>
#include <memory>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

namespace xgboost {
namespace tree {

/*!
 * \brief Exercises the "prune" tree updater against the min_split_loss
 *        threshold (configured to 10 below).
 *
 * The root of a fresh tree is split by hand into two leaves and the pruner is
 * run three times:
 *   1. root loss_chg not assigned -> the split is expected to be pruned away;
 *   2. root loss_chg = 11 (> 10)  -> the split is expected to be kept;
 *   3. root loss_chg = 10 (== 10) -> the split is still expected to be kept,
 *      i.e. the test expects only splits strictly below the threshold to be
 *      pruned.
 */
TEST(Updater, Prune) {
  int constexpr n_rows = 32, n_cols = 16;

  std::vector<std::pair<std::string, std::string>> cfg;
  cfg.emplace_back("num_feature", std::to_string(n_cols));
  cfg.emplace_back("min_split_loss", "10");
  cfg.emplace_back("silent", "1");

  // These data are just place holders.
  HostDeviceVector<GradientPair> gpair =
      { {0.50f, 0.25f}, {0.50f, 0.25f}, {0.50f, 0.25f}, {0.50f, 0.25f},
        {0.25f, 0.24f}, {0.25f, 0.24f}, {0.25f, 0.24f}, {0.25f, 0.24f} };

  // CreateDMatrix hands back a raw owning pointer (previously released with a
  // trailing `delete`). Transfer it to a unique_ptr so it is freed even when
  // one of the ASSERT_* macros below fails and returns early from this body.
  // 0.4 and 3 are forwarded unchanged (presumably sparsity and RNG seed --
  // confirm against helpers.h).
  auto dmat_raw = CreateDMatrix(n_rows, n_cols, 0.4, 3);
  std::unique_ptr<std::remove_pointer<decltype(dmat_raw)>::type> const dmat{
      dmat_raw};

  // prepare tree
  RegTree tree = RegTree();
  tree.InitModel();
  tree.param.InitAllowUnknown(cfg);
  std::vector<RegTree*> trees {&tree};
  // prepare pruner
  std::unique_ptr<TreeUpdater> pruner(TreeUpdater::Create("prune"));
  pruner->Init(cfg);

  // Splits node 0 into two leaves with values 0.3 (left) and 0.4 (right).
  auto split_root = [&tree]() {
    tree.AddChilds(0);
    int const cleft = tree[0].LeftChild();
    int const cright = tree[0].RightChild();
    tree[cleft].SetLeaf(0.3f, 0);
    tree[cright].SetLeaf(0.4f, 0);
  };

  // loss_chg < min_split_loss: the root's loss_chg is not assigned here, so
  // this relies on its initial value being below 10 (presumably 0 -- TODO
  // confirm); the two children are expected to be removed.
  split_root();
  pruner->Update(&gpair, dmat->get(), trees);
  ASSERT_EQ(tree.NumExtraNodes(), 0);

  // loss_chg > min_split_loss: the split is expected to survive.
  split_root();
  tree.Stat(0).loss_chg = 11;
  pruner->Update(&gpair, dmat->get(), trees);
  ASSERT_EQ(tree.NumExtraNodes(), 2);

  // loss_chg == min_split_loss: the split is still expected to survive.
  tree.Stat(0).loss_chg = 10;
  pruner->Update(&gpair, dmat->get(), trees);
  ASSERT_EQ(tree.NumExtraNodes(), 2);
}

}  // namespace tree
}  // namespace xgboost