refactor tests to get rid of duplication (#4358)
* refactor tests to get rid of duplication * address review comments
This commit is contained in:
committed by
Philip Hyunsu Cho
parent
3078b5944d
commit
f4521bf6aa
@@ -33,13 +33,7 @@ TEST(gpu_predictor, Test) {
|
||||
gpu_predictor->Init({}, {});
|
||||
cpu_predictor->Init({}, {});
|
||||
|
||||
std::vector<std::unique_ptr<RegTree>> trees;
|
||||
trees.push_back(std::unique_ptr<RegTree>(new RegTree()));
|
||||
(*trees.back())[0].SetLeaf(1.5f);
|
||||
(*trees.back()).Stat(0).sum_hess = 1.0f;
|
||||
gbm::GBTreeModel model(0.5);
|
||||
model.CommitModel(std::move(trees), 0);
|
||||
model.param.num_output_group = 1;
|
||||
gbm::GBTreeModel model = CreateTestModel();
|
||||
|
||||
int n_row = 5;
|
||||
int n_col = 5;
|
||||
@@ -181,13 +175,7 @@ TEST(gpu_predictor, MGPU_Test) {
|
||||
int n_row = i, n_col = i;
|
||||
auto dmat = CreateDMatrix(n_row, n_col, 0);
|
||||
|
||||
std::vector<std::unique_ptr<RegTree>> trees;
|
||||
trees.push_back(std::unique_ptr<RegTree>(new RegTree()));
|
||||
(*trees.back())[0].SetLeaf(1.5f);
|
||||
(*trees.back()).Stat(0).sum_hess = 1.0f;
|
||||
gbm::GBTreeModel model(0.5);
|
||||
model.CommitModel(std::move(trees), 0);
|
||||
model.param.num_output_group = 1;
|
||||
gbm::GBTreeModel model = CreateTestModel();
|
||||
|
||||
// Test predict batch
|
||||
HostDeviceVector<float> gpu_out_predictions;
|
||||
|
||||
Reference in New Issue
Block a user