/*! * Copyright 2018-2022 by XGBoost Contributors */ #include #include #include "../../../src/common/column_matrix.h" #include "../helpers.h" namespace xgboost { namespace common { TEST(DenseColumn, Test) { int32_t max_num_bins[] = {static_cast(std::numeric_limits::max()) + 1, static_cast(std::numeric_limits::max()) + 1, static_cast(std::numeric_limits::max()) + 2}; BinTypeSize last{kUint8BinsTypeSize}; for (int32_t max_num_bin : max_num_bins) { auto dmat = RandomDataGenerator(100, 10, 0.0).GenerateDMatrix(); auto sparse_thresh = 0.2; GHistIndexMatrix gmat{dmat.get(), max_num_bin, sparse_thresh, false, common::OmpGetNumThreads(0)}; ColumnMatrix column_matrix; for (auto const& page : dmat->GetBatches()) { column_matrix.Init(page, gmat, sparse_thresh, common::OmpGetNumThreads(0)); } ASSERT_GE(column_matrix.GetTypeSize(), last); ASSERT_LE(column_matrix.GetTypeSize(), kUint32BinsTypeSize); last = column_matrix.GetTypeSize(); ASSERT_FALSE(column_matrix.AnyMissing()); for (auto i = 0ull; i < dmat->Info().num_row_; i++) { for (auto j = 0ull; j < dmat->Info().num_col_; j++) { DispatchBinType(column_matrix.GetTypeSize(), [&](auto dtype) { using T = decltype(dtype); auto col = column_matrix.DenseColumn(j); ASSERT_EQ(gmat.index[i * dmat->Info().num_col_ + j], col.GetGlobalBinIdx(i)); }); } } } } template void CheckSparseColumn(SparseColumnIter* p_col, const GHistIndexMatrix& gmat) { auto& col = *p_col; size_t n_samples = gmat.row_ptr.size() - 1; ASSERT_EQ(col.Size(), gmat.index.Size()); for (auto i = 0ull; i < col.Size(); i++) { ASSERT_EQ(gmat.index[gmat.row_ptr[col.GetRowIdx(i)]], col.GetGlobalBinIdx(i)); } for (auto i = 0ull; i < n_samples; i++) { if (col[i] == Column::kMissingId) { auto beg = gmat.row_ptr[i]; auto end = gmat.row_ptr[i + 1]; ASSERT_EQ(end - beg, 0); } } } TEST(SparseColumn, Test) { int32_t max_num_bins[] = {static_cast(std::numeric_limits::max()) + 1, static_cast(std::numeric_limits::max()) + 1, static_cast(std::numeric_limits::max()) + 2}; for (int32_t max_num_bin : max_num_bins) { auto dmat = RandomDataGenerator(100, 1, 0.85).GenerateDMatrix(); GHistIndexMatrix gmat{dmat.get(), max_num_bin, 0.5f, false, common::OmpGetNumThreads(0)}; ColumnMatrix column_matrix; for (auto const& page : dmat->GetBatches()) { column_matrix.Init(page, gmat, 1.0, common::OmpGetNumThreads(0)); } common::DispatchBinType(column_matrix.GetTypeSize(), [&](auto dtype) { using T = decltype(dtype); auto col = column_matrix.SparseColumn(0, 0); CheckSparseColumn(&col, gmat); }); } } template void CheckColumWithMissingValue(const DenseColumnIter& col, const GHistIndexMatrix& gmat) { for (auto i = 0ull; i < col.Size(); i++) { if (col.IsMissing(i)) continue; EXPECT_EQ(gmat.index[gmat.row_ptr[i]], col.GetGlobalBinIdx(i)); } } TEST(DenseColumnWithMissing, Test) { int32_t max_num_bins[] = {static_cast(std::numeric_limits::max()) + 1, static_cast(std::numeric_limits::max()) + 1, static_cast(std::numeric_limits::max()) + 2}; for (int32_t max_num_bin : max_num_bins) { auto dmat = RandomDataGenerator(100, 1, 0.5).GenerateDMatrix(); GHistIndexMatrix gmat(dmat.get(), max_num_bin, 0.2, false, common::OmpGetNumThreads(0)); ColumnMatrix column_matrix; for (auto const& page : dmat->GetBatches()) { column_matrix.Init(page, gmat, 0.2, common::OmpGetNumThreads(0)); } ASSERT_TRUE(column_matrix.AnyMissing()); DispatchBinType(column_matrix.GetTypeSize(), [&](auto dtype) { using T = decltype(dtype); auto col = column_matrix.DenseColumn(0); CheckColumWithMissingValue(col, gmat); }); } } } // namespace common } // namespace xgboost