refactor tests to get rid of duplication (#4358)

* refactor tests to get rid of duplication

* address review comments
This commit is contained in:
Rong Ou
2019-04-12 00:21:48 -07:00
committed by Philip Hyunsu Cho
parent 3078b5944d
commit f4521bf6aa
5 changed files with 50 additions and 67 deletions

View File

@@ -5,6 +5,7 @@
#include "xgboost/c_api.h"
#include <random>
#include <cinttypes>
#include <dmlc/filesystem.h>
bool FileExists(const std::string& filename) {
struct stat st;
@@ -142,4 +143,38 @@ std::shared_ptr<xgboost::DMatrix>* CreateDMatrix(int rows, int columns,
return static_cast<std::shared_ptr<xgboost::DMatrix> *>(handle);
}
std::unique_ptr<DMatrix> CreateSparsePageDMatrix() {
// Create sufficiently large data to make two row pages
dmlc::TemporaryDirectory tempdir;
const std::string tmp_file = tempdir.path + "/big.libsvm";
CreateBigTestData(tmp_file, 12);
std::unique_ptr<DMatrix> dmat = std::unique_ptr<DMatrix>(DMatrix::Load(
tmp_file + "#" + tmp_file + ".cache", true, false, "auto", 64UL));
EXPECT_TRUE(FileExists(tmp_file + ".cache.row.page"));
// Loop over the batches and count the records
int64_t batch_count = 0;
int64_t row_count = 0;
for (const auto &batch : dmat->GetRowBatches()) {
batch_count++;
row_count += batch.Size();
}
EXPECT_EQ(batch_count, 2);
EXPECT_EQ(row_count, dmat->Info().num_row_);
return dmat;
}
gbm::GBTreeModel CreateTestModel() {
std::vector<std::unique_ptr<RegTree>> trees;
trees.push_back(std::unique_ptr<RegTree>(new RegTree));
(*trees.back())[0].SetLeaf(1.5f);
(*trees.back()).Stat(0).sum_hess = 1.0f;
gbm::GBTreeModel model(0.5);
model.CommitModel(std::move(trees), 0);
model.param.num_output_group = 1;
model.base_margin = 0;
return model;
}
} // namespace xgboost