Refactor tests with data generator. (#5439)

This commit is contained in:
Jiaming Yuan
2020-03-27 06:44:44 +08:00
committed by GitHub
parent 7146b91d5a
commit 4942da64ae
26 changed files with 334 additions and 259 deletions

View File

@@ -15,16 +15,18 @@ namespace xgboost {
namespace tree {
TEST(Updater, Refresh) {
int constexpr kNRows = 8, kNCols = 16;
bst_row_t constexpr kRows = 8;
bst_feature_t constexpr kCols = 16;
HostDeviceVector<GradientPair> gpair =
{ {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f},
{0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f} };
auto dmat = CreateDMatrix(kNRows, kNCols, 0.4, 3);
std::vector<std::pair<std::string, std::string>> cfg {
{"reg_alpha", "0.0"},
{"num_feature", std::to_string(kNCols)},
{"reg_lambda", "1"}};
std::shared_ptr<DMatrix> p_dmat{
RandomDataGenerator{kRows, kCols, 0.4f}.Seed(3).GenerateDMatix()};
std::vector<std::pair<std::string, std::string>> cfg{
{"reg_alpha", "0.0"},
{"num_feature", std::to_string(kCols)},
{"reg_lambda", "1"}};
RegTree tree = RegTree();
auto lparam = CreateEmptyGenericParam(GPUIDX);
@@ -40,7 +42,7 @@ TEST(Updater, Refresh) {
tree.Stat(cright).base_weight = 1.3;
refresher->Configure(cfg);
refresher->Update(&gpair, dmat->get(), trees);
refresher->Update(&gpair, p_dmat.get(), trees);
bst_float constexpr kEps = 1e-6;
ASSERT_NEAR(-0.183392, tree[cright].LeafValue(), kEps);
@@ -48,8 +50,6 @@ TEST(Updater, Refresh) {
ASSERT_NEAR(0, tree.Stat(cleft).loss_chg, kEps);
ASSERT_NEAR(0, tree.Stat(1).loss_chg, kEps);
ASSERT_NEAR(0, tree.Stat(2).loss_chg, kEps);
delete dmat;
}
} // namespace tree