Fix and cleanup for column matrix. (#7901)
* Fix missed type dispatching for dense columns with missing values. * Code cleanup to reduce special cases. * Reduce memory usage.
This commit is contained in:
@@ -15,6 +15,7 @@ TEST(DenseColumn, Test) {
|
||||
int32_t max_num_bins[] = {static_cast<int32_t>(std::numeric_limits<uint8_t>::max()) + 1,
|
||||
static_cast<int32_t>(std::numeric_limits<uint16_t>::max()) + 1,
|
||||
static_cast<int32_t>(std::numeric_limits<uint16_t>::max()) + 2};
|
||||
BinTypeSize last{kUint8BinsTypeSize};
|
||||
for (int32_t max_num_bin : max_num_bins) {
|
||||
auto dmat = RandomDataGenerator(100, 10, 0.0).GenerateDMatrix();
|
||||
auto sparse_thresh = 0.2;
|
||||
@@ -24,7 +25,10 @@ TEST(DenseColumn, Test) {
|
||||
for (auto const& page : dmat->GetBatches<SparsePage>()) {
|
||||
column_matrix.Init(page, gmat, sparse_thresh, common::OmpGetNumThreads(0));
|
||||
}
|
||||
|
||||
ASSERT_GE(column_matrix.GetTypeSize(), last);
|
||||
ASSERT_LE(column_matrix.GetTypeSize(), kUint32BinsTypeSize);
|
||||
last = column_matrix.GetTypeSize();
|
||||
ASSERT_FALSE(column_matrix.AnyMissing());
|
||||
for (auto i = 0ull; i < dmat->Info().num_row_; i++) {
|
||||
for (auto j = 0ull; j < dmat->Info().num_col_; j++) {
|
||||
switch (column_matrix.GetTypeSize()) {
|
||||
@@ -105,6 +109,7 @@ TEST(DenseColumnWithMissing, Test) {
|
||||
for (auto const& page : dmat->GetBatches<SparsePage>()) {
|
||||
column_matrix.Init(page, gmat, 0.2, common::OmpGetNumThreads(0));
|
||||
}
|
||||
ASSERT_TRUE(column_matrix.AnyMissing());
|
||||
switch (column_matrix.GetTypeSize()) {
|
||||
case kUint8BinsTypeSize: {
|
||||
auto col = column_matrix.DenseColumn<uint8_t, true>(0);
|
||||
|
||||
Reference in New Issue
Block a user