Use adapter to initialize column matrix. (#7912)

This commit is contained in:
Jiaming Yuan
2022-05-18 16:15:12 +08:00
committed by GitHub
parent 5ef33adf68
commit 19775ffe15
4 changed files with 82 additions and 73 deletions

View File

@@ -31,34 +31,33 @@ TEST(DenseColumn, Test) {
ASSERT_FALSE(column_matrix.AnyMissing());
for (auto i = 0ull; i < dmat->Info().num_row_; i++) {
for (auto j = 0ull; j < dmat->Info().num_col_; j++) {
switch (column_matrix.GetTypeSize()) {
case kUint8BinsTypeSize: {
auto col = column_matrix.DenseColumn<uint8_t, false>(j);
ASSERT_EQ(gmat.index[i * dmat->Info().num_col_ + j], col.GetGlobalBinIdx(i));
} break;
case kUint16BinsTypeSize: {
auto col = column_matrix.DenseColumn<uint16_t, false>(j);
ASSERT_EQ(gmat.index[i * dmat->Info().num_col_ + j], col.GetGlobalBinIdx(i));
} break;
case kUint32BinsTypeSize: {
auto col = column_matrix.DenseColumn<uint32_t, false>(j);
ASSERT_EQ(gmat.index[i * dmat->Info().num_col_ + j], col.GetGlobalBinIdx(i));
} break;
}
DispatchBinType(column_matrix.GetTypeSize(), [&](auto dtype) {
using T = decltype(dtype);
auto col = column_matrix.DenseColumn<T, false>(j);
ASSERT_EQ(gmat.index[i * dmat->Info().num_col_ + j], col.GetGlobalBinIdx(i));
});
}
}
}
}
template <typename BinIdxType>
inline void CheckSparseColumn(const SparseColumnIter<BinIdxType>& col_input,
const GHistIndexMatrix& gmat) {
const SparseColumnIter<BinIdxType>& col =
static_cast<const SparseColumnIter<BinIdxType>&>(col_input);
void CheckSparseColumn(SparseColumnIter<BinIdxType>* p_col, const GHistIndexMatrix& gmat) {
auto& col = *p_col;
size_t n_samples = gmat.row_ptr.size() - 1;
ASSERT_EQ(col.Size(), gmat.index.Size());
for (auto i = 0ull; i < col.Size(); i++) {
ASSERT_EQ(gmat.index[gmat.row_ptr[col.GetRowIdx(i)]], col.GetGlobalBinIdx(i));
}
for (auto i = 0ull; i < n_samples; i++) {
if (col[i] == Column<BinIdxType>::kMissingId) {
auto beg = gmat.row_ptr[i];
auto end = gmat.row_ptr[i + 1];
ASSERT_EQ(end - beg, 0);
}
}
}
TEST(SparseColumn, Test) {
@@ -72,26 +71,17 @@ TEST(SparseColumn, Test) {
for (auto const& page : dmat->GetBatches<SparsePage>()) {
column_matrix.Init(page, gmat, 1.0, common::OmpGetNumThreads(0));
}
switch (column_matrix.GetTypeSize()) {
case kUint8BinsTypeSize: {
auto col = column_matrix.SparseColumn<uint8_t>(0, 0);
CheckSparseColumn(col, gmat);
} break;
case kUint16BinsTypeSize: {
auto col = column_matrix.SparseColumn<uint16_t>(0, 0);
CheckSparseColumn(col, gmat);
} break;
case kUint32BinsTypeSize: {
auto col = column_matrix.SparseColumn<uint32_t>(0, 0);
CheckSparseColumn(col, gmat);
} break;
}
common::DispatchBinType(column_matrix.GetTypeSize(), [&](auto dtype) {
using T = decltype(dtype);
auto col = column_matrix.SparseColumn<T>(0, 0);
CheckSparseColumn(&col, gmat);
});
}
}
template <typename BinIdxType>
inline void CheckColumWithMissingValue(const DenseColumnIter<BinIdxType, true>& col,
const GHistIndexMatrix& gmat) {
void CheckColumWithMissingValue(const DenseColumnIter<BinIdxType, true>& col,
const GHistIndexMatrix& gmat) {
for (auto i = 0ull; i < col.Size(); i++) {
if (col.IsMissing(i)) continue;
EXPECT_EQ(gmat.index[gmat.row_ptr[i]], col.GetGlobalBinIdx(i));
@@ -110,20 +100,11 @@ TEST(DenseColumnWithMissing, Test) {
column_matrix.Init(page, gmat, 0.2, common::OmpGetNumThreads(0));
}
ASSERT_TRUE(column_matrix.AnyMissing());
switch (column_matrix.GetTypeSize()) {
case kUint8BinsTypeSize: {
auto col = column_matrix.DenseColumn<uint8_t, true>(0);
CheckColumWithMissingValue(col, gmat);
} break;
case kUint16BinsTypeSize: {
auto col = column_matrix.DenseColumn<uint16_t, true>(0);
CheckColumWithMissingValue(col, gmat);
} break;
case kUint32BinsTypeSize: {
auto col = column_matrix.DenseColumn<uint32_t, true>(0);
CheckColumWithMissingValue(col, gmat);
} break;
}
DispatchBinType(column_matrix.GetTypeSize(), [&](auto dtype) {
using T = decltype(dtype);
auto col = column_matrix.DenseColumn<T, true>(0);
CheckColumWithMissingValue(col, gmat);
});
}
}