Reduce 'InitSampling' complexity and set gradients to zero (#6922)
Co-authored-by: Kirill Shvets <kirill.shvets@intel.com>
This commit is contained in:
@@ -31,15 +31,15 @@ class QuantileHistMock : public QuantileHistMaker {
|
||||
std::unique_ptr<TreeUpdater> pruner,
|
||||
FeatureInteractionConstraintHost int_constraint,
|
||||
DMatrix const* fmat)
|
||||
: RealImpl(param, std::move(pruner),
|
||||
: RealImpl(1, param, std::move(pruner),
|
||||
std::move(int_constraint), fmat) {}
|
||||
|
||||
public:
|
||||
void TestInitData(const GHistIndexMatrix& gmat,
|
||||
const std::vector<GradientPair>& gpair,
|
||||
std::vector<GradientPair>* gpair,
|
||||
DMatrix* p_fmat,
|
||||
const RegTree& tree) {
|
||||
RealImpl::InitData(gmat, gpair, *p_fmat, tree);
|
||||
RealImpl::InitData(gmat, *p_fmat, tree, gpair);
|
||||
ASSERT_EQ(this->data_layout_, RealImpl::DataLayout::kSparseData);
|
||||
|
||||
/* The creation of HistCutMatrix and GHistIndexMatrix are not technically
|
||||
@@ -101,29 +101,34 @@ class QuantileHistMock : public QuantileHistMaker {
|
||||
}
|
||||
|
||||
void TestInitDataSampling(const GHistIndexMatrix& gmat,
|
||||
const std::vector<GradientPair>& gpair,
|
||||
std::vector<GradientPair>* gpair,
|
||||
DMatrix* p_fmat,
|
||||
const RegTree& tree) {
|
||||
// check SimpleSkip
|
||||
size_t initial_seed = 777;
|
||||
std::linear_congruential_engine<std::uint_fast64_t, 16807, 0,
|
||||
static_cast<uint64_t>(1) << 63 > eng_first(initial_seed);
|
||||
for (size_t i = 0; i < 100; ++i) {
|
||||
eng_first();
|
||||
}
|
||||
uint64_t initial_seed_th = RandomReplace::SimpleSkip(100, initial_seed, 16807, RandomReplace::kMod);
|
||||
std::linear_congruential_engine<std::uint_fast64_t, RandomReplace::kBase, 0,
|
||||
RandomReplace::kMod > eng_second(initial_seed_th);
|
||||
ASSERT_EQ(eng_first(), eng_second());
|
||||
|
||||
const size_t nthreads = omp_get_num_threads();
|
||||
// save state of global rng engine
|
||||
auto initial_rnd = common::GlobalRandom();
|
||||
std::vector<size_t> unused_rows_cpy = this->unused_rows_;
|
||||
RealImpl::InitData(gmat, gpair, *p_fmat, tree);
|
||||
RealImpl::InitData(gmat, *p_fmat, tree, gpair);
|
||||
std::vector<size_t> row_indices_initial = *(this->row_set_collection_.Data());
|
||||
std::vector<size_t> unused_row_indices_initial = this->unused_rows_;
|
||||
ASSERT_EQ(row_indices_initial.size(), p_fmat->Info().num_row_);
|
||||
auto check_each_row_occurs_in_one_of_arrays = [](const std::vector<size_t>& first,
|
||||
const std::vector<size_t>& second,
|
||||
size_t nrows) {
|
||||
std::vector<size_t> arr_union(nrows);
|
||||
for (auto&& row_indice : first) {
|
||||
++arr_union[row_indice];
|
||||
}
|
||||
for (auto&& row_indice : second) {
|
||||
++arr_union[row_indice];
|
||||
}
|
||||
for (auto&& row_cnt : arr_union) {
|
||||
ASSERT_EQ(row_cnt, 1ul);
|
||||
}
|
||||
ASSERT_EQ(first.size(), nrows);
|
||||
ASSERT_EQ(second.size(), 0);
|
||||
};
|
||||
check_each_row_occurs_in_one_of_arrays(row_indices_initial, unused_row_indices_initial,
|
||||
p_fmat->Info().num_row_);
|
||||
@@ -133,7 +138,7 @@ class QuantileHistMock : public QuantileHistMaker {
|
||||
// return initial state of global rng engine
|
||||
common::GlobalRandom() = initial_rnd;
|
||||
this->unused_rows_ = unused_rows_cpy;
|
||||
RealImpl::InitData(gmat, gpair, *p_fmat, tree);
|
||||
RealImpl::InitData(gmat, *p_fmat, tree, gpair);
|
||||
std::vector<size_t>& row_indices = *(this->row_set_collection_.Data());
|
||||
ASSERT_EQ(row_indices_initial.size(), row_indices.size());
|
||||
for (size_t i = 0; i < row_indices_initial.size(); ++i) {
|
||||
@@ -151,10 +156,10 @@ class QuantileHistMock : public QuantileHistMaker {
|
||||
}
|
||||
|
||||
void TestAddHistRows(const GHistIndexMatrix& gmat,
|
||||
const std::vector<GradientPair>& gpair,
|
||||
std::vector<GradientPair>* gpair,
|
||||
DMatrix* p_fmat,
|
||||
RegTree* tree) {
|
||||
RealImpl::InitData(gmat, gpair, *p_fmat, *tree);
|
||||
RealImpl::InitData(gmat, *p_fmat, *tree, gpair);
|
||||
|
||||
int starting_index = std::numeric_limits<int>::max();
|
||||
int sync_count = 0;
|
||||
@@ -183,11 +188,11 @@ class QuantileHistMock : public QuantileHistMaker {
|
||||
|
||||
|
||||
void TestSyncHistograms(const GHistIndexMatrix& gmat,
|
||||
const std::vector<GradientPair>& gpair,
|
||||
std::vector<GradientPair>* gpair,
|
||||
DMatrix* p_fmat,
|
||||
RegTree* tree) {
|
||||
// init
|
||||
RealImpl::InitData(gmat, gpair, *p_fmat, *tree);
|
||||
RealImpl::InitData(gmat, *p_fmat, *tree, gpair);
|
||||
|
||||
int starting_index = std::numeric_limits<int>::max();
|
||||
int sync_count = 0;
|
||||
@@ -295,10 +300,10 @@ class QuantileHistMock : public QuantileHistMaker {
|
||||
const GHistIndexMatrix& gmat,
|
||||
const DMatrix& fmat,
|
||||
const RegTree& tree) {
|
||||
const std::vector<GradientPair> gpair =
|
||||
std::vector<GradientPair> gpair =
|
||||
{ {0.23f, 0.24f}, {0.24f, 0.25f}, {0.26f, 0.27f}, {0.27f, 0.28f},
|
||||
{0.27f, 0.29f}, {0.37f, 0.39f}, {0.47f, 0.49f}, {0.57f, 0.59f} };
|
||||
RealImpl::InitData(gmat, gpair, fmat, tree);
|
||||
RealImpl::InitData(gmat, fmat, tree, &gpair);
|
||||
GHistIndexBlockMatrix dummy;
|
||||
this->hist_.AddHistRow(nid);
|
||||
this->hist_.AllocateAllData();
|
||||
@@ -341,7 +346,7 @@ class QuantileHistMock : public QuantileHistMaker {
|
||||
common::GHistIndexMatrix gmat;
|
||||
gmat.Init(dmat.get(), kMaxBins);
|
||||
|
||||
RealImpl::InitData(gmat, row_gpairs, *dmat, tree);
|
||||
RealImpl::InitData(gmat, *dmat, tree, &row_gpairs);
|
||||
this->hist_.AddHistRow(0);
|
||||
this->hist_.AllocateAllData();
|
||||
this->BuildHist(row_gpairs, this->row_set_collection_[0],
|
||||
@@ -437,7 +442,7 @@ class QuantileHistMock : public QuantileHistMaker {
|
||||
|
||||
// treat everything as dense, as this is what we intend to test here
|
||||
cm.Init(gmat, 0.0);
|
||||
RealImpl::InitData(gmat, row_gpairs, *dmat, tree);
|
||||
RealImpl::InitData(gmat, *dmat, tree, &row_gpairs);
|
||||
this->hist_.AddHistRow(0);
|
||||
this->hist_.AllocateAllData();
|
||||
RealImpl::InitNewNode(0, gmat, row_gpairs, *dmat, tree);
|
||||
@@ -548,9 +553,9 @@ class QuantileHistMock : public QuantileHistMaker {
|
||||
{ {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f},
|
||||
{0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f} };
|
||||
if (double_builder_) {
|
||||
double_builder_->TestInitData(gmat, gpair, dmat_.get(), tree);
|
||||
double_builder_->TestInitData(gmat, &gpair, dmat_.get(), tree);
|
||||
} else {
|
||||
float_builder_->TestInitData(gmat, gpair, dmat_.get(), tree);
|
||||
float_builder_->TestInitData(gmat, &gpair, dmat_.get(), tree);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -566,9 +571,9 @@ class QuantileHistMock : public QuantileHistMaker {
|
||||
{ {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f},
|
||||
{0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f} };
|
||||
if (double_builder_) {
|
||||
double_builder_->TestInitDataSampling(gmat, gpair, dmat_.get(), tree);
|
||||
double_builder_->TestInitDataSampling(gmat, &gpair, dmat_.get(), tree);
|
||||
} else {
|
||||
float_builder_->TestInitDataSampling(gmat, gpair, dmat_.get(), tree);
|
||||
float_builder_->TestInitDataSampling(gmat, &gpair, dmat_.get(), tree);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -583,9 +588,9 @@ class QuantileHistMock : public QuantileHistMaker {
|
||||
{ {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f},
|
||||
{0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f} };
|
||||
if (double_builder_) {
|
||||
double_builder_->TestAddHistRows(gmat, gpair, dmat_.get(), &tree);
|
||||
double_builder_->TestAddHistRows(gmat, &gpair, dmat_.get(), &tree);
|
||||
} else {
|
||||
float_builder_->TestAddHistRows(gmat, gpair, dmat_.get(), &tree);
|
||||
float_builder_->TestAddHistRows(gmat, &gpair, dmat_.get(), &tree);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -600,9 +605,9 @@ class QuantileHistMock : public QuantileHistMaker {
|
||||
{ {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f},
|
||||
{0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f} };
|
||||
if (double_builder_) {
|
||||
double_builder_->TestSyncHistograms(gmat, gpair, dmat_.get(), &tree);
|
||||
double_builder_->TestSyncHistograms(gmat, &gpair, dmat_.get(), &tree);
|
||||
} else {
|
||||
float_builder_->TestSyncHistograms(gmat, gpair, dmat_.get(), &tree);
|
||||
float_builder_->TestSyncHistograms(gmat, &gpair, dmat_.get(), &tree);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user