Refactor tests with data generator. (#5439)
This commit is contained in:
@@ -9,26 +9,25 @@ namespace xgboost {
|
||||
namespace common {
|
||||
|
||||
TEST(DenseColumn, Test) {
|
||||
auto dmat = CreateDMatrix(100, 10, 0.0);
|
||||
auto dmat = RandomDataGenerator(100, 10, 0.0).GenerateDMatix();
|
||||
GHistIndexMatrix gmat;
|
||||
gmat.Init((*dmat).get(), 256);
|
||||
gmat.Init(dmat.get(), 256);
|
||||
ColumnMatrix column_matrix;
|
||||
column_matrix.Init(gmat, 0.2);
|
||||
|
||||
for (auto i = 0ull; i < (*dmat)->Info().num_row_; i++) {
|
||||
for (auto j = 0ull; j < (*dmat)->Info().num_col_; j++) {
|
||||
for (auto i = 0ull; i < dmat->Info().num_row_; i++) {
|
||||
for (auto j = 0ull; j < dmat->Info().num_col_; j++) {
|
||||
auto col = column_matrix.GetColumn(j);
|
||||
ASSERT_EQ(gmat.index[i * (*dmat)->Info().num_col_ + j],
|
||||
ASSERT_EQ(gmat.index[i * dmat->Info().num_col_ + j],
|
||||
col.GetGlobalBinIdx(i));
|
||||
}
|
||||
}
|
||||
delete dmat;
|
||||
}
|
||||
|
||||
TEST(SparseColumn, Test) {
|
||||
auto dmat = CreateDMatrix(100, 1, 0.85);
|
||||
auto dmat = RandomDataGenerator(100, 1, 0.85).GenerateDMatix();
|
||||
GHistIndexMatrix gmat;
|
||||
gmat.Init((*dmat).get(), 256);
|
||||
gmat.Init(dmat.get(), 256);
|
||||
ColumnMatrix column_matrix;
|
||||
column_matrix.Init(gmat, 0.5);
|
||||
auto col = column_matrix.GetColumn(0);
|
||||
@@ -37,13 +36,12 @@ TEST(SparseColumn, Test) {
|
||||
ASSERT_EQ(gmat.index[gmat.row_ptr[col.GetRowIdx(i)]],
|
||||
col.GetGlobalBinIdx(i));
|
||||
}
|
||||
delete dmat;
|
||||
}
|
||||
|
||||
TEST(DenseColumnWithMissing, Test) {
|
||||
auto dmat = CreateDMatrix(100, 1, 0.5);
|
||||
auto dmat = RandomDataGenerator(100, 1, 0.5).GenerateDMatix();
|
||||
GHistIndexMatrix gmat;
|
||||
gmat.Init((*dmat).get(), 256);
|
||||
gmat.Init(dmat.get(), 256);
|
||||
ColumnMatrix column_matrix;
|
||||
column_matrix.Init(gmat, 0.2);
|
||||
auto col = column_matrix.GetColumn(0);
|
||||
@@ -52,7 +50,6 @@ TEST(DenseColumnWithMissing, Test) {
|
||||
EXPECT_EQ(gmat.index[gmat.row_ptr[col.GetRowIdx(i)]],
|
||||
col.GetGlobalBinIdx(i));
|
||||
}
|
||||
delete dmat;
|
||||
}
|
||||
|
||||
void TestGHistIndexMatrixCreation(size_t nthreads) {
|
||||
|
||||
@@ -128,8 +128,7 @@ TEST(CutsBuilder, SearchGroupInd) {
|
||||
size_t constexpr kRows = 17;
|
||||
size_t constexpr kCols = 15;
|
||||
|
||||
auto pp_dmat = CreateDMatrix(kRows, kCols, 0);
|
||||
std::shared_ptr<DMatrix> p_mat {*pp_dmat};
|
||||
auto p_mat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatix();
|
||||
|
||||
std::vector<bst_int> group(kNumGroups);
|
||||
group[0] = 2;
|
||||
@@ -149,8 +148,6 @@ TEST(CutsBuilder, SearchGroupInd) {
|
||||
ASSERT_EQ(group_ind, 2);
|
||||
|
||||
EXPECT_ANY_THROW(CutsBuilder::SearchGroupIndFromRow(p_mat->Info().group_ptr_, 17));
|
||||
|
||||
delete pp_dmat;
|
||||
}
|
||||
|
||||
TEST(SparseCuts, SingleThreadedBuild) {
|
||||
@@ -158,8 +155,7 @@ TEST(SparseCuts, SingleThreadedBuild) {
|
||||
size_t constexpr kCols = 31;
|
||||
size_t constexpr kBins = 256;
|
||||
|
||||
auto pp_dmat = CreateDMatrix(kRows, kCols, 0);
|
||||
std::shared_ptr<DMatrix> p_fmat {*pp_dmat};
|
||||
auto p_fmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatix();
|
||||
|
||||
common::GHistIndexMatrix hmat;
|
||||
hmat.Init(p_fmat.get(), kBins);
|
||||
@@ -173,8 +169,6 @@ TEST(SparseCuts, SingleThreadedBuild) {
|
||||
ASSERT_EQ(hmat.cut.Ptrs(), cuts.Ptrs());
|
||||
ASSERT_EQ(hmat.cut.Values(), cuts.Values());
|
||||
ASSERT_EQ(hmat.cut.MinValues(), cuts.MinValues());
|
||||
|
||||
delete pp_dmat;
|
||||
}
|
||||
|
||||
TEST(SparseCuts, MultiThreadedBuild) {
|
||||
@@ -212,17 +206,13 @@ TEST(SparseCuts, MultiThreadedBuild) {
|
||||
};
|
||||
|
||||
{
|
||||
auto pp_mat = CreateDMatrix(kRows, kCols, 0);
|
||||
DMatrix* p_fmat = (*pp_mat).get();
|
||||
Compare(p_fmat);
|
||||
delete pp_mat;
|
||||
auto p_fmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatix();
|
||||
Compare(p_fmat.get());
|
||||
}
|
||||
|
||||
{
|
||||
auto pp_mat = CreateDMatrix(kRows, kCols, 0.0001);
|
||||
DMatrix* p_fmat = (*pp_mat).get();
|
||||
Compare(p_fmat);
|
||||
delete pp_mat;
|
||||
auto p_fmat = RandomDataGenerator(kRows, kCols, 0.0001).GenerateDMatix();
|
||||
Compare(p_fmat.get());
|
||||
}
|
||||
|
||||
omp_set_num_threads(ori_nthreads);
|
||||
|
||||
@@ -128,7 +128,7 @@ inline void TestRank(const std::vector<float>& cuts,
|
||||
// Ignore the last cut, its special
|
||||
double sum_weight = 0.0;
|
||||
size_t j = 0;
|
||||
for (auto i = 0; i < cuts.size() - 1; i++) {
|
||||
for (size_t i = 0; i < cuts.size() - 1; i++) {
|
||||
while (cuts[i] > sorted_x[j]) {
|
||||
sum_weight += sorted_weights[j];
|
||||
j++;
|
||||
@@ -142,7 +142,7 @@ inline void TestRank(const std::vector<float>& cuts,
|
||||
inline void ValidateColumn(const HistogramCuts& cuts, int column_idx,
|
||||
const std::vector<float>& sorted_column,
|
||||
const std::vector<float>& sorted_weights,
|
||||
int num_bins) {
|
||||
size_t num_bins) {
|
||||
|
||||
// Check the endpoints are correct
|
||||
EXPECT_LT(cuts.MinValues()[column_idx], sorted_column.front());
|
||||
|
||||
Reference in New Issue
Block a user