Small cleanup to Column. (#7898)

* Define forward iterator to hide the internal state.
This commit is contained in:
Jiaming Yuan
2022-05-15 12:39:10 +08:00
committed by GitHub
parent ee382c4153
commit 1baad8650c
3 changed files with 131 additions and 152 deletions

View File

@@ -27,38 +27,33 @@ TEST(DenseColumn, Test) {
for (auto i = 0ull; i < dmat->Info().num_row_; i++) {
for (auto j = 0ull; j < dmat->Info().num_col_; j++) {
switch (column_matrix.GetTypeSize()) {
case kUint8BinsTypeSize: {
auto col = column_matrix.GetColumn<uint8_t, false>(j);
ASSERT_EQ(gmat.index[i * dmat->Info().num_col_ + j],
(*col.get()).GetGlobalBinIdx(i));
}
break;
case kUint16BinsTypeSize: {
auto col = column_matrix.GetColumn<uint16_t, false>(j);
ASSERT_EQ(gmat.index[i * dmat->Info().num_col_ + j],
(*col.get()).GetGlobalBinIdx(i));
}
break;
case kUint32BinsTypeSize: {
auto col = column_matrix.GetColumn<uint32_t, false>(j);
ASSERT_EQ(gmat.index[i * dmat->Info().num_col_ + j],
(*col.get()).GetGlobalBinIdx(i));
}
break;
switch (column_matrix.GetTypeSize()) {
case kUint8BinsTypeSize: {
auto col = column_matrix.DenseColumn<uint8_t, false>(j);
ASSERT_EQ(gmat.index[i * dmat->Info().num_col_ + j], col.GetGlobalBinIdx(i));
} break;
case kUint16BinsTypeSize: {
auto col = column_matrix.DenseColumn<uint16_t, false>(j);
ASSERT_EQ(gmat.index[i * dmat->Info().num_col_ + j], col.GetGlobalBinIdx(i));
} break;
case kUint32BinsTypeSize: {
auto col = column_matrix.DenseColumn<uint32_t, false>(j);
ASSERT_EQ(gmat.index[i * dmat->Info().num_col_ + j], col.GetGlobalBinIdx(i));
} break;
}
}
}
}
}
template<typename BinIdxType>
inline void CheckSparseColumn(const Column<BinIdxType>& col_input, const GHistIndexMatrix& gmat) {
const SparseColumn<BinIdxType>& col = static_cast<const SparseColumn<BinIdxType>& >(col_input);
template <typename BinIdxType>
inline void CheckSparseColumn(const SparseColumnIter<BinIdxType>& col_input,
const GHistIndexMatrix& gmat) {
const SparseColumnIter<BinIdxType>& col =
static_cast<const SparseColumnIter<BinIdxType>&>(col_input);
ASSERT_EQ(col.Size(), gmat.index.Size());
for (auto i = 0ull; i < col.Size(); i++) {
ASSERT_EQ(gmat.index[gmat.row_ptr[col.GetRowIdx(i)]],
col.GetGlobalBinIdx(i));
ASSERT_EQ(gmat.index[gmat.row_ptr[col.GetRowIdx(i)]], col.GetGlobalBinIdx(i));
}
}
@@ -75,32 +70,27 @@ TEST(SparseColumn, Test) {
}
switch (column_matrix.GetTypeSize()) {
case kUint8BinsTypeSize: {
auto col = column_matrix.GetColumn<uint8_t, true>(0);
CheckSparseColumn(*col.get(), gmat);
}
break;
auto col = column_matrix.SparseColumn<uint8_t>(0, 0);
CheckSparseColumn(col, gmat);
} break;
case kUint16BinsTypeSize: {
auto col = column_matrix.GetColumn<uint16_t, true>(0);
CheckSparseColumn(*col.get(), gmat);
}
break;
auto col = column_matrix.SparseColumn<uint16_t>(0, 0);
CheckSparseColumn(col, gmat);
} break;
case kUint32BinsTypeSize: {
auto col = column_matrix.GetColumn<uint32_t, true>(0);
CheckSparseColumn(*col.get(), gmat);
}
break;
auto col = column_matrix.SparseColumn<uint32_t>(0, 0);
CheckSparseColumn(col, gmat);
} break;
}
}
}
template<typename BinIdxType>
inline void CheckColumWithMissingValue(const Column<BinIdxType>& col_input,
template <typename BinIdxType>
inline void CheckColumWithMissingValue(const DenseColumnIter<BinIdxType, true>& col,
const GHistIndexMatrix& gmat) {
const DenseColumn<BinIdxType, true>& col = static_cast<const DenseColumn<BinIdxType, true>& >(col_input);
for (auto i = 0ull; i < col.Size(); i++) {
if (col.IsMissing(i)) continue;
EXPECT_EQ(gmat.index[gmat.row_ptr[i]],
col.GetGlobalBinIdx(i));
EXPECT_EQ(gmat.index[gmat.row_ptr[i]], col.GetGlobalBinIdx(i));
}
}
@@ -117,20 +107,17 @@ TEST(DenseColumnWithMissing, Test) {
}
switch (column_matrix.GetTypeSize()) {
case kUint8BinsTypeSize: {
auto col = column_matrix.GetColumn<uint8_t, true>(0);
CheckColumWithMissingValue(*col.get(), gmat);
}
break;
auto col = column_matrix.DenseColumn<uint8_t, true>(0);
CheckColumWithMissingValue(col, gmat);
} break;
case kUint16BinsTypeSize: {
auto col = column_matrix.GetColumn<uint16_t, true>(0);
CheckColumWithMissingValue(*col.get(), gmat);
}
break;
auto col = column_matrix.DenseColumn<uint16_t, true>(0);
CheckColumWithMissingValue(col, gmat);
} break;
case kUint32BinsTypeSize: {
auto col = column_matrix.GetColumn<uint32_t, true>(0);
CheckColumWithMissingValue(*col.get(), gmat);
}
break;
auto col = column_matrix.DenseColumn<uint32_t, true>(0);
CheckColumWithMissingValue(col, gmat);
} break;
}
}
}