Refactor tests with data generator. (#5439)
This commit is contained in:
@@ -15,16 +15,18 @@ namespace xgboost {
|
||||
namespace tree {
|
||||
|
||||
TEST(Updater, Refresh) {
|
||||
int constexpr kNRows = 8, kNCols = 16;
|
||||
bst_row_t constexpr kRows = 8;
|
||||
bst_feature_t constexpr kCols = 16;
|
||||
|
||||
HostDeviceVector<GradientPair> gpair =
|
||||
{ {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f},
|
||||
{0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f} };
|
||||
auto dmat = CreateDMatrix(kNRows, kNCols, 0.4, 3);
|
||||
std::vector<std::pair<std::string, std::string>> cfg {
|
||||
{"reg_alpha", "0.0"},
|
||||
{"num_feature", std::to_string(kNCols)},
|
||||
{"reg_lambda", "1"}};
|
||||
std::shared_ptr<DMatrix> p_dmat{
|
||||
RandomDataGenerator{kRows, kCols, 0.4f}.Seed(3).GenerateDMatix()};
|
||||
std::vector<std::pair<std::string, std::string>> cfg{
|
||||
{"reg_alpha", "0.0"},
|
||||
{"num_feature", std::to_string(kCols)},
|
||||
{"reg_lambda", "1"}};
|
||||
|
||||
RegTree tree = RegTree();
|
||||
auto lparam = CreateEmptyGenericParam(GPUIDX);
|
||||
@@ -40,7 +42,7 @@ TEST(Updater, Refresh) {
|
||||
tree.Stat(cright).base_weight = 1.3;
|
||||
|
||||
refresher->Configure(cfg);
|
||||
refresher->Update(&gpair, dmat->get(), trees);
|
||||
refresher->Update(&gpair, p_dmat.get(), trees);
|
||||
|
||||
bst_float constexpr kEps = 1e-6;
|
||||
ASSERT_NEAR(-0.183392, tree[cright].LeafValue(), kEps);
|
||||
@@ -48,8 +50,6 @@ TEST(Updater, Refresh) {
|
||||
ASSERT_NEAR(0, tree.Stat(cleft).loss_chg, kEps);
|
||||
ASSERT_NEAR(0, tree.Stat(1).loss_chg, kEps);
|
||||
ASSERT_NEAR(0, tree.Stat(2).loss_chg, kEps);
|
||||
|
||||
delete dmat;
|
||||
}
|
||||
|
||||
} // namespace tree
|
||||
|
||||
Reference in New Issue
Block a user