refactor tests to get rid of duplication (#4358)

* refactor tests to get rid of duplication

* address review comments
This commit is contained in:
Rong Ou
2019-04-12 00:21:48 -07:00
committed by Philip Hyunsu Cho
parent 3078b5944d
commit f4521bf6aa
5 changed files with 50 additions and 67 deletions

View File

@@ -33,13 +33,7 @@ TEST(gpu_predictor, Test) {
gpu_predictor->Init({}, {});
cpu_predictor->Init({}, {});
std::vector<std::unique_ptr<RegTree>> trees;
trees.push_back(std::unique_ptr<RegTree>(new RegTree()));
(*trees.back())[0].SetLeaf(1.5f);
(*trees.back()).Stat(0).sum_hess = 1.0f;
gbm::GBTreeModel model(0.5);
model.CommitModel(std::move(trees), 0);
model.param.num_output_group = 1;
gbm::GBTreeModel model = CreateTestModel();
int n_row = 5;
int n_col = 5;
@@ -181,13 +175,7 @@ TEST(gpu_predictor, MGPU_Test) {
int n_row = i, n_col = i;
auto dmat = CreateDMatrix(n_row, n_col, 0);
std::vector<std::unique_ptr<RegTree>> trees;
trees.push_back(std::unique_ptr<RegTree>(new RegTree()));
(*trees.back())[0].SetLeaf(1.5f);
(*trees.back()).Stat(0).sum_hess = 1.0f;
gbm::GBTreeModel model(0.5);
model.CommitModel(std::move(trees), 0);
model.param.num_output_group = 1;
gbm::GBTreeModel model = CreateTestModel();
// Test predict batch
HostDeviceVector<float> gpu_out_predictions;