Prepare external memory support for hist. (#7638)

This PR prepares the GHistIndexMatrix to host the column matrix which is used by the hist tree method by accepting sparse_threshold parameter.

Some cleanups are made to ensure the correct batch param is being passed into DMatrix along with some additional tests for correctness of SimpleDMatrix.
This commit is contained in:
Jiaming Yuan
2022-02-10 16:58:02 +08:00
committed by GitHub
parent 87c01f49d8
commit 2775c2a1ab
24 changed files with 368 additions and 201 deletions

View File

@@ -155,18 +155,19 @@ class QuantileHistMock : public QuantileHistMaker {
std::vector<GradientPair> row_gpairs =
{ {1.23f, 0.24f}, {0.24f, 0.25f}, {0.26f, 0.27f}, {2.27f, 0.28f},
{0.27f, 0.29f}, {0.37f, 0.39f}, {-0.47f, 0.49f}, {0.57f, 0.59f} };
size_t constexpr kMaxBins = 4;
int32_t constexpr kMaxBins = 4;
// try out different sparsity to get different number of missing values
for (double sparsity : {0.0, 0.1, 0.2}) {
// kNRows samples with kNCols features
auto dmat = RandomDataGenerator(kNRows, kNCols, sparsity).Seed(3).GenerateDMatrix();
GHistIndexMatrix gmat(dmat.get(), kMaxBins, false, common::OmpGetNumThreads(0));
float sparse_th = 0.0;
GHistIndexMatrix gmat{dmat.get(), kMaxBins, sparse_th, false, common::OmpGetNumThreads(0)};
ColumnMatrix cm;
// treat everything as dense, as this is what we intend to test here
cm.Init(gmat, 0.0, common::OmpGetNumThreads(0));
cm.Init(gmat, sparse_th, common::OmpGetNumThreads(0));
RealImpl::InitData(gmat, *dmat, tree, &row_gpairs);
const size_t num_row = dmat->Info().num_row_;
// split by feature 0
@@ -247,8 +248,8 @@ class QuantileHistMock : public QuantileHistMaker {
static size_t GetNumColumns() { return kNCols; }
void TestInitData() {
size_t constexpr kMaxBins = 4;
GHistIndexMatrix gmat(dmat_.get(), kMaxBins, false, common::OmpGetNumThreads(0));
int32_t constexpr kMaxBins = 4;
GHistIndexMatrix gmat{dmat_.get(), kMaxBins, 0.0f, false, common::OmpGetNumThreads(0)};
RegTree tree = RegTree();
tree.param.UpdateAllowUnknown(cfg_);
@@ -264,8 +265,8 @@ class QuantileHistMock : public QuantileHistMaker {
}
void TestInitDataSampling() {
size_t constexpr kMaxBins = 4;
GHistIndexMatrix gmat(dmat_.get(), kMaxBins, false, common::OmpGetNumThreads(0));
int32_t constexpr kMaxBins = 4;
GHistIndexMatrix gmat{dmat_.get(), kMaxBins, 0.0f, false, common::OmpGetNumThreads(0)};
RegTree tree = RegTree();
tree.param.UpdateAllowUnknown(cfg_);