Fix CPU hist init for sparse dataset. (#4625)
* Fix CPU hist init for sparse dataset. * Implement sparse histogram cut. * Allow empty features. * Fix windows build, don't use sparse in distributed environment. * Comments. * Smaller threshold. * Fix windows omp. * Fix msvc lambda capture. * Fix MSVC macro. * Fix MSVC initialization list. * Fix MSVC initialization list x2. * Preserve categorical feature behavior. * Rename matrix to sparse cuts. * Reuse UseGroup. * Check for categorical data when adding cut. Co-Authored-By: Philip Hyunsu Cho <chohyu01@cs.washington.edu> * Sanity check. * Fix comments. * Fix comment.
This commit is contained in:
committed by
Philip Hyunsu Cho
parent
b7a1f22d24
commit
d9a47794a5
@@ -9,15 +9,7 @@
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
|
||||
class HistCutMatrixMock : public HistCutMatrix {
|
||||
public:
|
||||
size_t SearchGroupIndFromBaseRow(
|
||||
std::vector<bst_uint> const& group_ptr, size_t const base_rowid) {
|
||||
return HistCutMatrix::SearchGroupIndFromBaseRow(group_ptr, base_rowid);
|
||||
}
|
||||
};
|
||||
|
||||
TEST(HistCutMatrix, SearchGroupInd) {
|
||||
TEST(CutsBuilder, SearchGroupInd) {
|
||||
size_t constexpr kNumGroups = 4;
|
||||
size_t constexpr kNumRows = 17;
|
||||
size_t constexpr kNumCols = 15;
|
||||
@@ -34,18 +26,102 @@ TEST(HistCutMatrix, SearchGroupInd) {
|
||||
p_mat->Info().SetInfo(
|
||||
"group", group.data(), DataType::kUInt32, kNumGroups);
|
||||
|
||||
HistCutMatrixMock hmat;
|
||||
HistogramCuts hmat;
|
||||
|
||||
size_t group_ind = hmat.SearchGroupIndFromBaseRow(p_mat->Info().group_ptr_, 0);
|
||||
size_t group_ind = CutsBuilder::SearchGroupIndFromRow(p_mat->Info().group_ptr_, 0);
|
||||
ASSERT_EQ(group_ind, 0);
|
||||
|
||||
group_ind = hmat.SearchGroupIndFromBaseRow(p_mat->Info().group_ptr_, 5);
|
||||
group_ind = CutsBuilder::SearchGroupIndFromRow(p_mat->Info().group_ptr_, 5);
|
||||
ASSERT_EQ(group_ind, 2);
|
||||
|
||||
EXPECT_ANY_THROW(hmat.SearchGroupIndFromBaseRow(p_mat->Info().group_ptr_, 17));
|
||||
EXPECT_ANY_THROW(CutsBuilder::SearchGroupIndFromRow(p_mat->Info().group_ptr_, 17));
|
||||
|
||||
delete pp_mat;
|
||||
}
|
||||
|
||||
namespace {
|
||||
class SparseCutsWrapper : public SparseCuts {
|
||||
public:
|
||||
std::vector<uint32_t> const& ColPtrs() const { return p_cuts_->Ptrs(); }
|
||||
std::vector<float> const& ColValues() const { return p_cuts_->Values(); }
|
||||
};
|
||||
} // anonymous namespace
|
||||
|
||||
TEST(SparseCuts, SingleThreadedBuild) {
|
||||
size_t constexpr kRows = 267;
|
||||
size_t constexpr kCols = 31;
|
||||
size_t constexpr kBins = 256;
|
||||
|
||||
// Dense matrix.
|
||||
auto pp_mat = CreateDMatrix(kRows, kCols, 0);
|
||||
DMatrix* p_fmat = (*pp_mat).get();
|
||||
|
||||
common::GHistIndexMatrix hmat;
|
||||
hmat.Init(p_fmat, kBins);
|
||||
|
||||
HistogramCuts cuts;
|
||||
SparseCuts indices(&cuts);
|
||||
auto const& page = *(p_fmat->GetColumnBatches().begin());
|
||||
indices.SingleThreadBuild(page, p_fmat->Info(), kBins, false, 0, page.Size(), 0);
|
||||
|
||||
ASSERT_EQ(hmat.cut.Ptrs().size(), cuts.Ptrs().size());
|
||||
ASSERT_EQ(hmat.cut.Ptrs(), cuts.Ptrs());
|
||||
ASSERT_EQ(hmat.cut.Values(), cuts.Values());
|
||||
ASSERT_EQ(hmat.cut.MinValues(), cuts.MinValues());
|
||||
|
||||
delete pp_mat;
|
||||
}
|
||||
|
||||
TEST(SparseCuts, MultiThreadedBuild) {
|
||||
size_t constexpr kRows = 17;
|
||||
size_t constexpr kCols = 15;
|
||||
size_t constexpr kBins = 255;
|
||||
|
||||
omp_ulong ori_nthreads = omp_get_max_threads();
|
||||
omp_set_num_threads(16);
|
||||
|
||||
auto Compare =
|
||||
#if defined(_MSC_VER) // msvc fails to capture
|
||||
[kBins](DMatrix* p_fmat) {
|
||||
#else
|
||||
[](DMatrix* p_fmat) {
|
||||
#endif
|
||||
HistogramCuts threaded_container;
|
||||
SparseCuts threaded_indices(&threaded_container);
|
||||
threaded_indices.Build(p_fmat, kBins);
|
||||
|
||||
HistogramCuts container;
|
||||
SparseCuts indices(&container);
|
||||
auto const& page = *(p_fmat->GetColumnBatches().begin());
|
||||
indices.SingleThreadBuild(page, p_fmat->Info(), kBins, false, 0, page.Size(), 0);
|
||||
|
||||
ASSERT_EQ(container.Ptrs().size(), threaded_container.Ptrs().size());
|
||||
ASSERT_EQ(container.Values().size(), threaded_container.Values().size());
|
||||
|
||||
for (uint32_t i = 0; i < container.Ptrs().size(); ++i) {
|
||||
ASSERT_EQ(container.Ptrs()[i], threaded_container.Ptrs()[i]);
|
||||
}
|
||||
for (uint32_t i = 0; i < container.Values().size(); ++i) {
|
||||
ASSERT_EQ(container.Values()[i], threaded_container.Values()[i]);
|
||||
}
|
||||
};
|
||||
|
||||
{
|
||||
auto pp_mat = CreateDMatrix(kRows, kCols, 0);
|
||||
DMatrix* p_fmat = (*pp_mat).get();
|
||||
Compare(p_fmat);
|
||||
delete pp_mat;
|
||||
}
|
||||
|
||||
{
|
||||
auto pp_mat = CreateDMatrix(kRows, kCols, 0.0001);
|
||||
DMatrix* p_fmat = (*pp_mat).get();
|
||||
Compare(p_fmat);
|
||||
delete pp_mat;
|
||||
}
|
||||
|
||||
omp_set_num_threads(ori_nthreads);
|
||||
}
|
||||
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
|
||||
Reference in New Issue
Block a user