/*!
 * Copyright 2018 by Contributors
 */
#include "../helpers.h"
#include "../../../src/common/host_device_vector.h"
#include <xgboost/tree_updater.h>
#include <gtest/gtest.h>

#include <memory>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

namespace xgboost {
namespace tree {

// Checks the "prune" updater against min_split_loss = 10:
//   - a split whose loss_chg is below min_split_loss is collapsed
//     (no extra nodes remain);
//   - splits whose loss_chg is above or equal to min_split_loss are kept
//     (the two child nodes remain).
TEST(Updater, Prune) {
  int constexpr kNRows = 32, kNCols = 16;

  std::vector<std::pair<std::string, std::string>> cfg;
  // Construct the pairs in place rather than building temporaries.
  cfg.emplace_back("num_feature", std::to_string(kNCols));
  cfg.emplace_back("min_split_loss", "10");
  cfg.emplace_back("silent", "1");

  // These data are just place holders.
  HostDeviceVector<GradientPair> gpair =
      { {0.50f, 0.25f}, {0.50f, 0.25f}, {0.50f, 0.25f}, {0.50f, 0.25f},
        {0.25f, 0.24f}, {0.25f, 0.24f}, {0.25f, 0.24f}, {0.25f, 0.24f} };

  // CreateDMatrix returns a heap-allocated object that this test owns (it
  // was previously released with a raw `delete`). Holding it in a unique_ptr
  // frees it even when an ASSERT_* below fails and returns early.
  // Use the named constants so the matrix shape cannot drift from
  // num_feature.
  // NOTE(review): 0.4 and 3 are presumably sparsity and RNG seed — confirm
  // against helpers.h.
  auto* raw_dmat = CreateDMatrix(kNRows, kNCols, 0.4, 3);
  std::unique_ptr<std::remove_pointer<decltype(raw_dmat)>::type> dmat {raw_dmat};

  // prepare tree
  RegTree tree;
  tree.param.InitAllowUnknown(cfg);
  std::vector<RegTree*> trees {&tree};

  // prepare pruner
  std::unique_ptr<TreeUpdater> pruner(TreeUpdater::Create("prune"));
  pruner->Init(cfg);

  // loss_chg < min_split_loss: the split is pruned away.
  tree.ExpandNode(0, 0, 0, true, 0.0f, 0.3f, 0.4f, 0.0f, 0.0f);
  pruner->Update(&gpair, dmat->get(), trees);
  ASSERT_EQ(tree.NumExtraNodes(), 0);

  // loss_chg > min_split_loss: the split survives.
  tree.ExpandNode(0, 0, 0, true, 0.0f, 0.3f, 0.4f, 11.0f, 0.0f);
  pruner->Update(&gpair, dmat->get(), trees);
  ASSERT_EQ(tree.NumExtraNodes(), 2);

  // loss_chg == min_split_loss: the boundary value is not pruned.
  tree.Stat(0).loss_chg = 10;
  pruner->Update(&gpair, dmat->get(), trees);
  ASSERT_EQ(tree.NumExtraNodes(), 2);
}

}  // namespace tree
}  // namespace xgboost