Refactor tests with data generator. (#5439)

This commit is contained in:
Jiaming Yuan
2020-03-27 06:44:44 +08:00
committed by GitHub
parent 7146b91d5a
commit 4942da64ae
26 changed files with 334 additions and 259 deletions

View File

@@ -139,23 +139,23 @@ class QuantileHistMock : public QuantileHistMaker {
{ {1.23f, 0.24f}, {0.24f, 0.25f}, {0.26f, 0.27f}, {2.27f, 0.28f},
{0.27f, 0.29f}, {0.37f, 0.39f}, {-0.47f, 0.49f}, {0.57f, 0.59f} };
size_t constexpr kMaxBins = 4;
auto dmat = CreateDMatrix(kNRows, kNCols, 0, 3);
auto dmat = RandomDataGenerator(kNRows, kNCols, 0).Seed(3).GenerateDMatix();
// dense, no missing values
common::GHistIndexMatrix gmat;
gmat.Init((*dmat).get(), kMaxBins);
gmat.Init(dmat.get(), kMaxBins);
RealImpl::InitData(gmat, row_gpairs, *(*dmat), tree);
RealImpl::InitData(gmat, row_gpairs, *dmat, tree);
hist_.AddHistRow(0);
BuildHist(row_gpairs, row_set_collection_[0],
gmat, quantile_index_block, hist_[0]);
RealImpl::InitNewNode(0, gmat, row_gpairs, *(*dmat), tree);
RealImpl::InitNewNode(0, gmat, row_gpairs, *dmat, tree);
/* Compute correct split (best_split) using the computed histogram */
const size_t num_row = dmat->get()->Info().num_row_;
const size_t num_feature = dmat->get()->Info().num_col_;
const size_t num_row = dmat->Info().num_row_;
const size_t num_feature = dmat->Info().num_col_;
CHECK_EQ(num_row, row_gpairs.size());
// Compute total gradient for all data points
GradientPairPrecise total_gpair;
@@ -216,8 +216,6 @@ class QuantileHistMock : public QuantileHistMaker {
RealImpl::EvaluateSplits({node}, gmat, hist_, tree);
ASSERT_EQ(snode_[0].best.SplitIndex(), best_split_feature);
ASSERT_EQ(snode_[0].best.split_value, gmat.cut.Values()[best_split_threshold]);
delete dmat;
}
void TestEvaluateSplitParallel(const GHistIndexBlockMatrix &quantile_index_block,
@@ -230,7 +228,7 @@ class QuantileHistMock : public QuantileHistMaker {
};
int static constexpr kNRows = 8, kNCols = 16;
std::shared_ptr<xgboost::DMatrix> *dmat_;
std::shared_ptr<xgboost::DMatrix> dmat_;
const std::vector<std::pair<std::string, std::string> > cfg_;
std::shared_ptr<BuilderMock> builder_;
@@ -240,23 +238,23 @@ class QuantileHistMock : public QuantileHistMaker {
cfg_{args} {
QuantileHistMaker::Configure(args);
spliteval_->Init(&param_);
dmat_ = CreateDMatrix(kNRows, kNCols, 0.8, 3);
dmat_ = RandomDataGenerator(kNRows, kNCols, 0.8).Seed(3).GenerateDMatix();
builder_.reset(
new BuilderMock(
param_,
std::move(pruner_),
std::unique_ptr<SplitEvaluator>(spliteval_->GetHostClone()),
int_constraint_,
dmat_->get()));
dmat_.get()));
}
~QuantileHistMock() override { delete dmat_; }
~QuantileHistMock() override = default;
static size_t GetNumColumns() { return kNCols; }
void TestInitData() {
size_t constexpr kMaxBins = 4;
common::GHistIndexMatrix gmat;
gmat.Init((*dmat_).get(), kMaxBins);
gmat.Init(dmat_.get(), kMaxBins);
RegTree tree = RegTree();
tree.param.UpdateAllowUnknown(cfg_);
@@ -265,7 +263,7 @@ class QuantileHistMock : public QuantileHistMaker {
{ {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f},
{0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f} };
builder_->TestInitData(gmat, gpair, dmat_->get(), tree);
builder_->TestInitData(gmat, gpair, dmat_.get(), tree);
}
void TestBuildHist() {
@@ -274,9 +272,9 @@ class QuantileHistMock : public QuantileHistMaker {
size_t constexpr kMaxBins = 4;
common::GHistIndexMatrix gmat;
gmat.Init((*dmat_).get(), kMaxBins);
gmat.Init(dmat_.get(), kMaxBins);
builder_->TestBuildHist(0, gmat, *(*dmat_).get(), tree);
builder_->TestBuildHist(0, gmat, *dmat_, tree);
}
void TestEvaluateSplit() {