Refactor for GHistIndex. (#7923)

* Pass sparse page as adapter, which prepares for quantile dmatrix.
* Remove old external memory code like `rbegin` and extra `Init` function.
* Simplify type dispatch.
This commit is contained in:
Jiaming Yuan
2022-05-23 23:04:53 +08:00
committed by GitHub
parent d314680a15
commit 18a38f7ca0
6 changed files with 113 additions and 157 deletions

View File

@@ -107,22 +107,5 @@ TEST(DenseColumnWithMissing, Test) {
});
}
}
// Smoke-test helper: builds a DMatrix via CreateSparsePageDMatrix and constructs
// a GHistIndexMatrix from it, using a thread count derived from `nthreads`
// through common::OmpGetNumThreads.
// The helper makes no assertions; it only requires that construction completes.
void TestGHistIndexMatrixCreation(size_t nthreads) {
size_t constexpr kPageSize = 1024, kEntriesPerCol = 3;
// 1024 * 3 * 2 = 6144 entries requested in total.
size_t constexpr kEntries = kPageSize * kEntriesPerCol * 2;
/* This should create multiple sparse pages */
std::unique_ptr<DMatrix> dmat{CreateSparsePageDMatrix(kEntries)};
// NOTE(review): 256 and 0.5f appear to be the max-bin count and the sparse
// threshold, judging by how this constructor is called elsewhere -- confirm
// against the GHistIndexMatrix declaration.
GHistIndexMatrix gmat(dmat.get(), 256, 0.5f, false, common::OmpGetNumThreads(nthreads));
}
// Runs the GHistIndexMatrix creation helper three times, each with a different
// requested thread count (20, 30, 40). Passing means every construction completes.
TEST(HistIndexCreationWithExternalMemory, Test) {
// Vary the number of threads to make sure that the last batch
// is distributed properly to the available number of threads
// in the thread pool
TestGHistIndexMatrixCreation(20);
TestGHistIndexMatrixCreation(30);
TestGHistIndexMatrixCreation(40);
}
} // namespace common
} // namespace xgboost

View File

@@ -44,9 +44,7 @@ TEST(GradientIndex, FromCategoricalBasic) {
h_ft.resize(kCols, FeatureType::kCategorical);
BatchParam p(max_bins, 0.8);
GHistIndexMatrix gidx;
gidx.Init(m.get(), max_bins, p.sparse_thresh, false, common::OmpGetNumThreads(0), {});
GHistIndexMatrix gidx(m.get(), max_bins, p.sparse_thresh, false, common::OmpGetNumThreads(0), {});
auto x_copy = x;
std::sort(x_copy.begin(), x_copy.end());

View File

@@ -413,10 +413,16 @@ void TestHistogramExternalMemory(BatchParam batch_param, bool is_approx) {
single_build.Reset(total_bins, batch_param, common::OmpGetNumThreads(0), 1, false);
SparsePage concat;
GHistIndexMatrix gmat;
std::vector<float> hess(m->Info().num_row_, 1.0f);
gmat.Init(m.get(), batch_param.max_bin, std::numeric_limits<double>::quiet_NaN(), false,
common::OmpGetNumThreads(0), hess);
for (auto const& page : m->GetBatches<SparsePage>()) {
concat.Push(page);
}
auto cut = common::SketchOnDMatrix(m.get(), batch_param.max_bin, common::OmpGetNumThreads(0),
false, hess);
GHistIndexMatrix gmat;
gmat.Init(concat, {}, cut, batch_param.max_bin, false, std::numeric_limits<double>::quiet_NaN(),
common::OmpGetNumThreads(0));
single_build.BuildHist(0, gmat, &tree, row_set_collection, nodes, {}, h_gpair);
single_page = single_build.Histogram()[0];
}