Prepare external memory support for hist. (#7638)
This PR prepares the GHistIndexMatrix to host the column matrix which is used by the hist tree method by accepting sparse_threshold parameter. Some cleanups are made to ensure the correct batch param is being passed into DMatrix along with some additional tests for correctness of SimpleDMatrix.
This commit is contained in:
@@ -155,18 +155,19 @@ class QuantileHistMock : public QuantileHistMaker {
|
||||
std::vector<GradientPair> row_gpairs =
|
||||
{ {1.23f, 0.24f}, {0.24f, 0.25f}, {0.26f, 0.27f}, {2.27f, 0.28f},
|
||||
{0.27f, 0.29f}, {0.37f, 0.39f}, {-0.47f, 0.49f}, {0.57f, 0.59f} };
|
||||
size_t constexpr kMaxBins = 4;
|
||||
int32_t constexpr kMaxBins = 4;
|
||||
|
||||
// try out different sparsity to get different number of missing values
|
||||
for (double sparsity : {0.0, 0.1, 0.2}) {
|
||||
// kNRows samples with kNCols features
|
||||
auto dmat = RandomDataGenerator(kNRows, kNCols, sparsity).Seed(3).GenerateDMatrix();
|
||||
|
||||
GHistIndexMatrix gmat(dmat.get(), kMaxBins, false, common::OmpGetNumThreads(0));
|
||||
float sparse_th = 0.0;
|
||||
GHistIndexMatrix gmat{dmat.get(), kMaxBins, sparse_th, false, common::OmpGetNumThreads(0)};
|
||||
ColumnMatrix cm;
|
||||
|
||||
// treat everything as dense, as this is what we intend to test here
|
||||
cm.Init(gmat, 0.0, common::OmpGetNumThreads(0));
|
||||
cm.Init(gmat, sparse_th, common::OmpGetNumThreads(0));
|
||||
RealImpl::InitData(gmat, *dmat, tree, &row_gpairs);
|
||||
const size_t num_row = dmat->Info().num_row_;
|
||||
// split by feature 0
|
||||
@@ -247,8 +248,8 @@ class QuantileHistMock : public QuantileHistMaker {
|
||||
static size_t GetNumColumns() { return kNCols; }
|
||||
|
||||
void TestInitData() {
|
||||
size_t constexpr kMaxBins = 4;
|
||||
GHistIndexMatrix gmat(dmat_.get(), kMaxBins, false, common::OmpGetNumThreads(0));
|
||||
int32_t constexpr kMaxBins = 4;
|
||||
GHistIndexMatrix gmat{dmat_.get(), kMaxBins, 0.0f, false, common::OmpGetNumThreads(0)};
|
||||
|
||||
RegTree tree = RegTree();
|
||||
tree.param.UpdateAllowUnknown(cfg_);
|
||||
@@ -264,8 +265,8 @@ class QuantileHistMock : public QuantileHistMaker {
|
||||
}
|
||||
|
||||
void TestInitDataSampling() {
|
||||
size_t constexpr kMaxBins = 4;
|
||||
GHistIndexMatrix gmat(dmat_.get(), kMaxBins, false, common::OmpGetNumThreads(0));
|
||||
int32_t constexpr kMaxBins = 4;
|
||||
GHistIndexMatrix gmat{dmat_.get(), kMaxBins, 0.0f, false, common::OmpGetNumThreads(0)};
|
||||
|
||||
RegTree tree = RegTree();
|
||||
tree.param.UpdateAllowUnknown(cfg_);
|
||||
|
||||
Reference in New Issue
Block a user