Merge branch 'master' into dev-hui

This commit is contained in:
amdsc21
2023-03-08 00:39:33 +01:00
221 changed files with 3122 additions and 1486 deletions

View File

@@ -119,13 +119,13 @@ TEST(ArrayInterface, TrivialDim) {
}
TEST(ArrayInterface, ToDType) {
static_assert(ToDType<float>::kType == ArrayInterfaceHandler::kF4, "");
static_assert(ToDType<double>::kType == ArrayInterfaceHandler::kF8, "");
static_assert(ToDType<float>::kType == ArrayInterfaceHandler::kF4);
static_assert(ToDType<double>::kType == ArrayInterfaceHandler::kF8);
static_assert(ToDType<uint32_t>::kType == ArrayInterfaceHandler::kU4, "");
static_assert(ToDType<uint64_t>::kType == ArrayInterfaceHandler::kU8, "");
static_assert(ToDType<uint32_t>::kType == ArrayInterfaceHandler::kU4);
static_assert(ToDType<uint64_t>::kType == ArrayInterfaceHandler::kU8);
static_assert(ToDType<int32_t>::kType == ArrayInterfaceHandler::kI4, "");
static_assert(ToDType<int64_t>::kType == ArrayInterfaceHandler::kI8, "");
static_assert(ToDType<int32_t>::kType == ArrayInterfaceHandler::kI4);
static_assert(ToDType<int64_t>::kType == ArrayInterfaceHandler::kI8);
}
} // namespace xgboost

View File

@@ -21,7 +21,7 @@ TEST(SparsePage, PushCSC) {
offset = {0, 1, 4};
for (size_t i = 0; i < offset.back(); ++i) {
data.emplace_back(Entry(i, 0.1f));
data.emplace_back(i, 0.1f);
}
SparsePage other;

View File

@@ -68,6 +68,30 @@ TEST(GradientIndex, FromCategoricalBasic) {
}
}
TEST(GradientIndex, FromCategoricalLarge) {
size_t constexpr kRows = 1000, kCats = 512, kCols = 1;
bst_bin_t max_bins = 8;
auto x = GenerateRandomCategoricalSingleColumn(kRows, kCats);
auto m = GetDMatrixFromData(x, kRows, 1);
Context ctx;
auto &h_ft = m->Info().feature_types.HostVector();
h_ft.resize(kCols, FeatureType::kCategorical);
BatchParam p{max_bins, 0.8};
{
GHistIndexMatrix gidx(m.get(), max_bins, p.sparse_thresh, false, AllThreadsForTest(), {});
ASSERT_TRUE(gidx.index.GetBinTypeSize() == common::kUint16BinsTypeSize);
}
{
for (auto const &page : m->GetBatches<GHistIndexMatrix>(p)) {
common::HistogramCuts cut = page.cut;
GHistIndexMatrix gidx{m->Info(), std::move(cut), max_bins};
ASSERT_EQ(gidx.MaxNumBinPerFeat(), kCats);
}
}
}
TEST(GradientIndex, PushBatch) {
size_t constexpr kRows = 64, kCols = 4;
bst_bin_t max_bins = 64;

View File

@@ -189,8 +189,8 @@ TEST(SimpleCSRSource, FromColumnarSparse) {
auto& mask = column_bitfields[0];
mask.resize(8);
for (size_t j = 0; j < mask.size(); ++j) {
mask[j] = ~0;
for (auto && j : mask) {
j = ~0;
}
// the 2^th entry of first column is invalid
// [0 0 0 0 0 1 0 0]
@@ -201,8 +201,8 @@ TEST(SimpleCSRSource, FromColumnarSparse) {
auto& mask = column_bitfields[1];
mask.resize(8);
for (size_t j = 0; j < mask.size(); ++j) {
mask[j] = ~0;
for (auto && j : mask) {
j = ~0;
}
// the 19^th entry of second column is invalid
// [~0~], [~0~], [0 0 0 0 1 0 0 0]

View File

@@ -96,7 +96,7 @@ void TestRetainPage() {
// make sure it's const and the caller can not modify the content of page.
for (auto& page : m->GetBatches<Page>()) {
static_assert(std::is_const<std::remove_reference_t<decltype(page)>>::value, "");
static_assert(std::is_const<std::remove_reference_t<decltype(page)>>::value);
}
}

View File

@@ -1,5 +1,6 @@
// Copyright by Contributors
/**
* Copyright 2019-2023 by XGBoost Contributors
*/
#include "../../../src/common/compressed_iterator.h"
#include "../../../src/data/ellpack_page.cuh"
#include "../../../src/data/sparse_page_dmatrix.h"
@@ -69,7 +70,7 @@ TEST(SparsePageDMatrix, RetainEllpackPage) {
std::vector<std::shared_ptr<EllpackPage const>> iterators;
for (auto it = begin; it != end; ++it) {
iterators.push_back(it.Page());
gidx_buffers.emplace_back(HostDeviceVector<common::CompressedByteT>{});
gidx_buffers.emplace_back();
gidx_buffers.back().Resize((*it).Impl()->gidx_buffer.Size());
gidx_buffers.back().Copy((*it).Impl()->gidx_buffer);
}
@@ -87,7 +88,7 @@ TEST(SparsePageDMatrix, RetainEllpackPage) {
// make sure it's const and the caller can not modify the content of page.
for (auto& page : m->GetBatches<EllpackPage>({0, 32})) {
static_assert(std::is_const<std::remove_reference_t<decltype(page)>>::value, "");
static_assert(std::is_const<std::remove_reference_t<decltype(page)>>::value);
}
// The above iteration clears out all references inside DMatrix.