Reducing memory consumption for 'hist' method on CPU (#5334)
This commit is contained in:
@@ -347,5 +347,106 @@ TEST(hist_util, SparseCutsExternalMemory) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(hist_util, IndexBinBound) {
|
||||
uint64_t bin_sizes[] = { static_cast<uint64_t>(std::numeric_limits<uint8_t>::max()) + 1,
|
||||
static_cast<uint64_t>(std::numeric_limits<uint16_t>::max()) + 1,
|
||||
static_cast<uint64_t>(std::numeric_limits<uint16_t>::max()) + 2 };
|
||||
BinTypeSize expected_bin_type_sizes[] = {UINT8_BINS_TYPE_SIZE,
|
||||
UINT16_BINS_TYPE_SIZE,
|
||||
UINT32_BINS_TYPE_SIZE};
|
||||
size_t constexpr kRows = 100;
|
||||
size_t constexpr kCols = 10;
|
||||
|
||||
size_t bin_id = 0;
|
||||
for (auto max_bin : bin_sizes) {
|
||||
auto p_fmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatix();
|
||||
|
||||
common::GHistIndexMatrix hmat;
|
||||
hmat.Init(p_fmat.get(), max_bin);
|
||||
EXPECT_EQ(hmat.index.size(), kRows*kCols);
|
||||
EXPECT_EQ(expected_bin_type_sizes[bin_id++], hmat.index.getBinTypeSize());
|
||||
}
|
||||
}
|
||||
|
||||
TEST(hist_util, SparseIndexBinBound) {
|
||||
uint64_t bin_sizes[] = { static_cast<uint64_t>(std::numeric_limits<uint8_t>::max()) + 1,
|
||||
static_cast<uint64_t>(std::numeric_limits<uint16_t>::max()) + 1,
|
||||
static_cast<uint64_t>(std::numeric_limits<uint16_t>::max()) + 2 };
|
||||
BinTypeSize expected_bin_type_sizes[] = { UINT32_BINS_TYPE_SIZE,
|
||||
UINT32_BINS_TYPE_SIZE,
|
||||
UINT32_BINS_TYPE_SIZE };
|
||||
size_t constexpr kRows = 100;
|
||||
size_t constexpr kCols = 10;
|
||||
|
||||
size_t bin_id = 0;
|
||||
for (auto max_bin : bin_sizes) {
|
||||
auto p_fmat = RandomDataGenerator(kRows, kCols, 0.2).GenerateDMatix();
|
||||
common::GHistIndexMatrix hmat;
|
||||
hmat.Init(p_fmat.get(), max_bin);
|
||||
EXPECT_EQ(expected_bin_type_sizes[bin_id++], hmat.index.getBinTypeSize());
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void CheckIndexData(T* data_ptr, uint32_t* offsets,
|
||||
const common::GHistIndexMatrix& hmat, size_t n_cols) {
|
||||
for (size_t i = 0; i < hmat.index.size(); ++i) {
|
||||
EXPECT_EQ(data_ptr[i] + offsets[i % n_cols], hmat.index[i]);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(hist_util, IndexBinData) {
|
||||
uint64_t constexpr kBinSizes[] = { static_cast<uint64_t>(std::numeric_limits<uint8_t>::max()) + 1,
|
||||
static_cast<uint64_t>(std::numeric_limits<uint16_t>::max()) + 1,
|
||||
static_cast<uint64_t>(std::numeric_limits<uint16_t>::max()) + 2 };
|
||||
size_t constexpr kRows = 100;
|
||||
size_t constexpr kCols = 10;
|
||||
|
||||
size_t bin_id = 0;
|
||||
for (auto max_bin : kBinSizes) {
|
||||
auto p_fmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatix();
|
||||
common::GHistIndexMatrix hmat;
|
||||
hmat.Init(p_fmat.get(), max_bin);
|
||||
uint32_t* offsets = hmat.index.offset();
|
||||
EXPECT_EQ(hmat.index.size(), kRows*kCols);
|
||||
switch (max_bin) {
|
||||
case kBinSizes[0]:
|
||||
CheckIndexData(hmat.index.data<uint8_t>(),
|
||||
offsets, hmat, kCols);
|
||||
break;
|
||||
case kBinSizes[1]:
|
||||
CheckIndexData(hmat.index.data<uint16_t>(),
|
||||
offsets, hmat, kCols);
|
||||
break;
|
||||
case kBinSizes[2]:
|
||||
CheckIndexData(hmat.index.data<uint32_t>(),
|
||||
offsets, hmat, kCols);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(hist_util, SparseIndexBinData) {
|
||||
uint64_t bin_sizes[] = { static_cast<uint64_t>(std::numeric_limits<uint8_t>::max()) + 1,
|
||||
static_cast<uint64_t>(std::numeric_limits<uint16_t>::max()) + 1,
|
||||
static_cast<uint64_t>(std::numeric_limits<uint16_t>::max()) + 2 };
|
||||
size_t constexpr kRows = 100;
|
||||
size_t constexpr kCols = 10;
|
||||
|
||||
size_t bin_id = 0;
|
||||
for (auto max_bin : bin_sizes) {
|
||||
auto p_fmat = RandomDataGenerator(kRows, kCols, 0.2).GenerateDMatix();
|
||||
common::GHistIndexMatrix hmat;
|
||||
hmat.Init(p_fmat.get(), max_bin);
|
||||
EXPECT_EQ(hmat.index.offset(), nullptr);
|
||||
|
||||
uint32_t* data_ptr = hmat.index.data<uint32_t>();
|
||||
for (size_t i = 0; i < hmat.index.size(); ++i) {
|
||||
EXPECT_EQ(data_ptr[i], hmat.index[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
|
||||
Reference in New Issue
Block a user