Optimizations for RNG in InitData kernel (#5522)
* optimizations for subsampling in InitData * optimizations for subsampling in InitData Co-authored-by: SHVETS, KIRILL <kirill.shvets@intel.com>
This commit is contained in:
@@ -96,6 +96,31 @@ class QuantileHistMock : public QuantileHistMaker {
|
||||
}
|
||||
}
|
||||
|
||||
void TestInitDataSampling(const GHistIndexMatrix& gmat,
|
||||
const std::vector<GradientPair>& gpair,
|
||||
DMatrix* p_fmat,
|
||||
const RegTree& tree) {
|
||||
const size_t nthreads = omp_get_num_threads();
|
||||
// save state of global rng engine
|
||||
auto initial_rnd = common::GlobalRandom();
|
||||
RealImpl::InitData(gmat, gpair, *p_fmat, tree);
|
||||
std::vector<size_t> row_indices_initial = *row_set_collection_.Data();
|
||||
|
||||
for (size_t i_nthreads = 1; i_nthreads < 4; ++i_nthreads) {
|
||||
omp_set_num_threads(i_nthreads);
|
||||
// return initial state of global rng engine
|
||||
common::GlobalRandom() = initial_rnd;
|
||||
RealImpl::InitData(gmat, gpair, *p_fmat, tree);
|
||||
std::vector<size_t>& row_indices = *row_set_collection_.Data();
|
||||
ASSERT_EQ(row_indices_initial.size(), row_indices.size());
|
||||
for (size_t i = 0; i < row_indices_initial.size(); ++i) {
|
||||
ASSERT_EQ(row_indices_initial[i], row_indices[i]);
|
||||
}
|
||||
}
|
||||
omp_set_num_threads(nthreads);
|
||||
}
|
||||
|
||||
|
||||
void TestBuildHist(int nid,
|
||||
const GHistIndexMatrix& gmat,
|
||||
const DMatrix& fmat,
|
||||
@@ -266,6 +291,20 @@ class QuantileHistMock : public QuantileHistMaker {
|
||||
builder_->TestInitData(gmat, gpair, dmat_.get(), tree);
|
||||
}
|
||||
|
||||
void TestInitDataSampling() {
|
||||
size_t constexpr kMaxBins = 4;
|
||||
common::GHistIndexMatrix gmat;
|
||||
gmat.Init(dmat_.get(), kMaxBins);
|
||||
|
||||
RegTree tree = RegTree();
|
||||
tree.param.UpdateAllowUnknown(cfg_);
|
||||
|
||||
std::vector<GradientPair> gpair =
|
||||
{ {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f},
|
||||
{0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f} };
|
||||
|
||||
builder_->TestInitDataSampling(gmat, gpair, dmat_.get(), tree);
|
||||
}
|
||||
void TestBuildHist() {
|
||||
RegTree tree = RegTree();
|
||||
tree.param.UpdateAllowUnknown(cfg_);
|
||||
@@ -292,6 +331,15 @@ TEST(QuantileHist, InitData) {
|
||||
maker.TestInitData();
|
||||
}
|
||||
|
||||
TEST(QuantileHist, InitDataSampling) {
|
||||
const float subsample = 0.5;
|
||||
std::vector<std::pair<std::string, std::string>> cfg
|
||||
{{"num_feature", std::to_string(QuantileHistMock::GetNumColumns())},
|
||||
{"subsample", std::to_string(subsample)}};
|
||||
QuantileHistMock maker(cfg);
|
||||
maker.TestInitDataSampling();
|
||||
}
|
||||
|
||||
TEST(QuantileHist, BuildHist) {
|
||||
// Don't enable feature grouping
|
||||
std::vector<std::pair<std::string, std::string>> cfg
|
||||
|
||||
Reference in New Issue
Block a user