Add float32 histogram (#5624)

* A new `single_precision_histogram` parameter was added to enable float32 histogram building.

Co-authored-by: SHVETS, KIRILL <kirill.shvets@intel.com>
Co-authored-by: fis <jm.yuan@outlook.com>
This commit is contained in:
ShvetsKS
2020-06-03 06:24:53 +03:00
committed by GitHub
parent e49607af19
commit cd3d14ad0e
10 changed files with 618 additions and 303 deletions

View File

@@ -20,8 +20,8 @@ size_t GetNThreads() {
return nthreads;
}
TEST(ParallelGHistBuilder, Reset) {
template <typename GradientSumT>
void ParallelGHistBuilderReset() {
constexpr size_t kBins = 10;
constexpr size_t kNodes = 5;
constexpr size_t kNodesExtended = 10;
@@ -29,16 +29,16 @@ TEST(ParallelGHistBuilder, Reset) {
constexpr double kValue = 1.0;
const size_t nthreads = GetNThreads();
HistCollection collection;
HistCollection<GradientSumT> collection;
collection.Init(kBins);
for(size_t inode = 0; inode < kNodesExtended; inode++) {
collection.AddHistRow(inode);
}
ParallelGHistBuilder hist_builder;
ParallelGHistBuilder<GradientSumT> hist_builder;
hist_builder.Init(kBins);
std::vector<GHistRow> target_hist(kNodes);
std::vector<GHistRow<GradientSumT>> target_hist(kNodes);
for(size_t i = 0; i < target_hist.size(); ++i) {
target_hist[i] = collection[i];
}
@@ -49,7 +49,7 @@ TEST(ParallelGHistBuilder, Reset) {
common::ParallelFor2d(space, nthreads, [&](size_t inode, common::Range1d r) {
const size_t tid = omp_get_thread_num();
GHistRow hist = hist_builder.GetInitializedHist(tid, inode);
GHistRow<GradientSumT> hist = hist_builder.GetInitializedHist(tid, inode);
// fill hist by some non-null values
for(size_t j = 0; j < kBins; ++j) {
hist[j].Add(kValue, kValue);
@@ -67,7 +67,7 @@ TEST(ParallelGHistBuilder, Reset) {
common::ParallelFor2d(space2, nthreads, [&](size_t inode, common::Range1d r) {
const size_t tid = omp_get_thread_num();
GHistRow hist = hist_builder.GetInitializedHist(tid, inode);
GHistRow<GradientSumT> hist = hist_builder.GetInitializedHist(tid, inode);
// fill hist by some non-null values
for(size_t j = 0; j < kBins; ++j) {
ASSERT_EQ(0.0, hist[j].GetGrad());
@@ -76,23 +76,25 @@ TEST(ParallelGHistBuilder, Reset) {
});
}
TEST(ParallelGHistBuilder, ReduceHist) {
template <typename GradientSumT>
void ParallelGHistBuilderReduceHist(){
constexpr size_t kBins = 10;
constexpr size_t kNodes = 5;
constexpr size_t kTasksPerNode = 10;
constexpr double kValue = 1.0;
const size_t nthreads = GetNThreads();
HistCollection collection;
HistCollection<GradientSumT> collection;
collection.Init(kBins);
for(size_t inode = 0; inode < kNodes; inode++) {
collection.AddHistRow(inode);
}
ParallelGHistBuilder hist_builder;
ParallelGHistBuilder<GradientSumT> hist_builder;
hist_builder.Init(kBins);
std::vector<GHistRow> target_hist(kNodes);
std::vector<GHistRow<GradientSumT>> target_hist(kNodes);
for(size_t i = 0; i < target_hist.size(); ++i) {
target_hist[i] = collection[i];
}
@@ -104,7 +106,7 @@ TEST(ParallelGHistBuilder, ReduceHist) {
common::ParallelFor2d(space, nthreads, [&](size_t inode, common::Range1d r) {
const size_t tid = omp_get_thread_num();
GHistRow hist = hist_builder.GetInitializedHist(tid, inode);
GHistRow<GradientSumT> hist = hist_builder.GetInitializedHist(tid, inode);
for(size_t i = 0; i < kBins; ++i) {
hist[i].Add(kValue, kValue);
}
@@ -122,6 +124,21 @@ TEST(ParallelGHistBuilder, ReduceHist) {
}
}
TEST(ParallelGHistBuilder, ResetDouble) {
ParallelGHistBuilderReset<double>();
}
TEST(ParallelGHistBuilder, ResetFloat) {
ParallelGHistBuilderReset<float>();
}
TEST(ParallelGHistBuilder, ReduceHistDouble) {
ParallelGHistBuilderReduceHist<double>();
}
TEST(ParallelGHistBuilder, ReduceHistFloat) {
ParallelGHistBuilderReduceHist<float>();
}
TEST(CutsBuilder, SearchGroupInd) {
size_t constexpr kNumGroups = 4;

View File

@@ -21,8 +21,11 @@ namespace tree {
class QuantileHistMock : public QuantileHistMaker {
static double constexpr kEps = 1e-6;
struct BuilderMock : public QuantileHistMaker::Builder {
using RealImpl = QuantileHistMaker::Builder;
template <typename GradientSumT>
struct BuilderMock : public QuantileHistMaker::Builder<GradientSumT> {
using RealImpl = QuantileHistMaker::Builder<GradientSumT>;
using ExpandEntryT = typename RealImpl::ExpandEntry;
using GHistRowT = typename RealImpl::GHistRowT;
BuilderMock(const TrainParam& param,
std::unique_ptr<TreeUpdater> pruner,
@@ -30,7 +33,7 @@ class QuantileHistMock : public QuantileHistMaker {
FeatureInteractionConstraintHost int_constraint,
DMatrix const* fmat)
: RealImpl(param, std::move(pruner), std::move(spliteval),
std::move(int_constraint), fmat) {}
std::move(int_constraint), fmat) {}
public:
void TestInitData(const GHistIndexMatrix& gmat,
@@ -38,7 +41,7 @@ class QuantileHistMock : public QuantileHistMaker {
DMatrix* p_fmat,
const RegTree& tree) {
RealImpl::InitData(gmat, gpair, *p_fmat, tree);
ASSERT_EQ(data_layout_, kSparseData);
ASSERT_EQ(this->data_layout_, RealImpl::kSparseData);
/* The creation of HistCutMatrix and GHistIndexMatrix are not technically
* part of QuantileHist updater logic, but we include it here because
@@ -105,14 +108,14 @@ class QuantileHistMock : public QuantileHistMaker {
// save state of global rng engine
auto initial_rnd = common::GlobalRandom();
RealImpl::InitData(gmat, gpair, *p_fmat, tree);
std::vector<size_t> row_indices_initial = *row_set_collection_.Data();
std::vector<size_t> row_indices_initial = *(this->row_set_collection_.Data());
for (size_t i_nthreads = 1; i_nthreads < 4; ++i_nthreads) {
omp_set_num_threads(i_nthreads);
// return initial state of global rng engine
common::GlobalRandom() = initial_rnd;
RealImpl::InitData(gmat, gpair, *p_fmat, tree);
std::vector<size_t>& row_indices = *row_set_collection_.Data();
std::vector<size_t>& row_indices = *(this->row_set_collection_.Data());
ASSERT_EQ(row_indices_initial.size(), row_indices.size());
for (size_t i = 0; i < row_indices_initial.size(); ++i) {
ASSERT_EQ(row_indices_initial[i], row_indices[i]);
@@ -129,26 +132,26 @@ class QuantileHistMock : public QuantileHistMaker {
int starting_index = std::numeric_limits<int>::max();
int sync_count = 0;
nodes_for_explicit_hist_build_.clear();
nodes_for_subtraction_trick_.clear();
this->nodes_for_explicit_hist_build_.clear();
this->nodes_for_subtraction_trick_.clear();
tree->ExpandNode(0, 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
tree->ExpandNode((*tree)[0].LeftChild(), 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
tree->ExpandNode((*tree)[0].RightChild(), 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
nodes_for_explicit_hist_build_.emplace_back(3, 4, tree->GetDepth(3), 0.0f, 0);
nodes_for_explicit_hist_build_.emplace_back(4, 3, tree->GetDepth(4), 0.0f, 0);
nodes_for_subtraction_trick_.emplace_back(5, 6, tree->GetDepth(5), 0.0f, 0);
nodes_for_subtraction_trick_.emplace_back(6, 5, tree->GetDepth(6), 0.0f, 0);
this->nodes_for_explicit_hist_build_.emplace_back(3, 4, tree->GetDepth(3), 0.0f, 0);
this->nodes_for_explicit_hist_build_.emplace_back(4, 3, tree->GetDepth(4), 0.0f, 0);
this->nodes_for_subtraction_trick_.emplace_back(5, 6, tree->GetDepth(5), 0.0f, 0);
this->nodes_for_subtraction_trick_.emplace_back(6, 5, tree->GetDepth(6), 0.0f, 0);
hist_rows_adder_->AddHistRows(this, &starting_index, &sync_count, tree);
this->hist_rows_adder_->AddHistRows(this, &starting_index, &sync_count, tree);
ASSERT_EQ(sync_count, 2);
ASSERT_EQ(starting_index, 3);
for (const ExpandEntry& node : nodes_for_explicit_hist_build_) {
ASSERT_EQ(hist_.RowExists(node.nid), true);
for (const ExpandEntryT& node : this->nodes_for_explicit_hist_build_) {
ASSERT_EQ(this->hist_.RowExists(node.nid), true);
}
for (const ExpandEntry& node : nodes_for_subtraction_trick_) {
ASSERT_EQ(hist_.RowExists(node.nid), true);
for (const ExpandEntryT& node : this->nodes_for_subtraction_trick_) {
ASSERT_EQ(this->hist_.RowExists(node.nid), true);
}
}
@@ -162,60 +165,61 @@ class QuantileHistMock : public QuantileHistMaker {
int starting_index = std::numeric_limits<int>::max();
int sync_count = 0;
nodes_for_explicit_hist_build_.clear();
nodes_for_subtraction_trick_.clear();
this->nodes_for_explicit_hist_build_.clear();
this->nodes_for_subtraction_trick_.clear();
// level 0
nodes_for_explicit_hist_build_.emplace_back(0, -1, tree->GetDepth(0), 0.0f, 0);
hist_rows_adder_->AddHistRows(this, &starting_index, &sync_count, tree);
this->nodes_for_explicit_hist_build_.emplace_back(0, -1, tree->GetDepth(0), 0.0f, 0);
this->hist_rows_adder_->AddHistRows(this, &starting_index, &sync_count, tree);
tree->ExpandNode(0, 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
nodes_for_explicit_hist_build_.clear();
nodes_for_subtraction_trick_.clear();
this->nodes_for_explicit_hist_build_.clear();
this->nodes_for_subtraction_trick_.clear();
// level 1
nodes_for_explicit_hist_build_.emplace_back((*tree)[0].LeftChild(), (*tree)[0].RightChild(),
this->nodes_for_explicit_hist_build_.emplace_back((*tree)[0].LeftChild(),
(*tree)[0].RightChild(),
tree->GetDepth(1), 0.0f, 0);
nodes_for_subtraction_trick_.emplace_back((*tree)[0].RightChild(), (*tree)[0].LeftChild(),
this->nodes_for_subtraction_trick_.emplace_back((*tree)[0].RightChild(),
(*tree)[0].LeftChild(),
tree->GetDepth(2), 0.0f, 0);
hist_rows_adder_->AddHistRows(this, &starting_index, &sync_count, tree);
this->hist_rows_adder_->AddHistRows(this, &starting_index, &sync_count, tree);
tree->ExpandNode((*tree)[0].LeftChild(), 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
tree->ExpandNode((*tree)[0].RightChild(), 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
nodes_for_explicit_hist_build_.clear();
nodes_for_subtraction_trick_.clear();
this->nodes_for_explicit_hist_build_.clear();
this->nodes_for_subtraction_trick_.clear();
// level 2
nodes_for_explicit_hist_build_.emplace_back(3, 4, tree->GetDepth(3), 0.0f, 0);
nodes_for_subtraction_trick_.emplace_back(4, 3, tree->GetDepth(4), 0.0f, 0);
nodes_for_explicit_hist_build_.emplace_back(5, 6, tree->GetDepth(5), 0.0f, 0);
nodes_for_subtraction_trick_.emplace_back(6, 5, tree->GetDepth(6), 0.0f, 0);
hist_rows_adder_->AddHistRows(this, &starting_index, &sync_count, tree);
this->nodes_for_explicit_hist_build_.emplace_back(3, 4, tree->GetDepth(3), 0.0f, 0);
this->nodes_for_subtraction_trick_.emplace_back(4, 3, tree->GetDepth(4), 0.0f, 0);
this->nodes_for_explicit_hist_build_.emplace_back(5, 6, tree->GetDepth(5), 0.0f, 0);
this->nodes_for_subtraction_trick_.emplace_back(6, 5, tree->GetDepth(6), 0.0f, 0);
this->hist_rows_adder_->AddHistRows(this, &starting_index, &sync_count, tree);
const size_t n_nodes = nodes_for_explicit_hist_build_.size();
const size_t n_nodes = this->nodes_for_explicit_hist_build_.size();
ASSERT_EQ(n_nodes, 2);
row_set_collection_.AddSplit(0, (*tree)[0].LeftChild(),
this->row_set_collection_.AddSplit(0, (*tree)[0].LeftChild(),
(*tree)[0].RightChild(), 4, 4);
row_set_collection_.AddSplit(1, (*tree)[1].LeftChild(),
this->row_set_collection_.AddSplit(1, (*tree)[1].LeftChild(),
(*tree)[1].RightChild(), 2, 2);
row_set_collection_.AddSplit(2, (*tree)[2].LeftChild(),
this->row_set_collection_.AddSplit(2, (*tree)[2].LeftChild(),
(*tree)[2].RightChild(), 2, 2);
common::BlockedSpace2d space(n_nodes, [&](size_t node) {
const int32_t nid = nodes_for_explicit_hist_build_[node].nid;
return row_set_collection_[nid].Size();
const int32_t nid = this->nodes_for_explicit_hist_build_[node].nid;
return this->row_set_collection_[nid].Size();
}, 256);
std::vector<GHistRow> target_hists(n_nodes);
for (size_t i = 0; i < nodes_for_explicit_hist_build_.size(); ++i) {
const int32_t nid = nodes_for_explicit_hist_build_[i].nid;
target_hists[i] = hist_[nid];
std::vector<GHistRowT> target_hists(n_nodes);
for (size_t i = 0; i < this->nodes_for_explicit_hist_build_.size(); ++i) {
const int32_t nid = this->nodes_for_explicit_hist_build_[i].nid;
target_hists[i] = this->hist_[nid];
}
const size_t nbins = hist_builder_.GetNumBins();
const size_t nbins = this->hist_builder_.GetNumBins();
// set values to specific nodes hist
std::vector<size_t> n_ids = {1, 2};
for (size_t i : n_ids) {
auto this_hist = hist_[i];
using FPType = decltype(tree::GradStats::sum_grad);
FPType* p_hist = reinterpret_cast<FPType*>(this_hist.data());
auto this_hist = this->hist_[i];
GradientSumT* p_hist = reinterpret_cast<GradientSumT*>(this_hist.data());
for (size_t bin_id = 0; bin_id < 2*nbins; ++bin_id) {
p_hist[bin_id] = 2*bin_id;
}
@@ -223,41 +227,39 @@ class QuantileHistMock : public QuantileHistMaker {
n_ids[0] = 3;
n_ids[1] = 5;
for (size_t i : n_ids) {
auto this_hist = hist_[i];
using FPType = decltype(tree::GradStats::sum_grad);
FPType* p_hist = reinterpret_cast<FPType*>(this_hist.data());
auto this_hist = this->hist_[i];
GradientSumT* p_hist = reinterpret_cast<GradientSumT*>(this_hist.data());
for (size_t bin_id = 0; bin_id < 2*nbins; ++bin_id) {
p_hist[bin_id] = bin_id;
}
}
hist_buffer_.Reset(1, n_nodes, space, target_hists);
this->hist_buffer_.Reset(1, n_nodes, space, target_hists);
// sync hist
hist_synchronizer_->SyncHistograms(this, starting_index, sync_count, tree);
this->hist_synchronizer_->SyncHistograms(this, starting_index, sync_count, tree);
auto check_hist = [] (const GHistRow parent, const GHistRow left,
const GHistRow right, size_t begin, size_t end) {
using FPType = decltype(tree::GradStats::sum_grad);
const FPType* p_parent = reinterpret_cast<const FPType*>(parent.data());
const FPType* p_left = reinterpret_cast<const FPType*>(left.data());
const FPType* p_right = reinterpret_cast<const FPType*>(right.data());
auto check_hist = [] (const GHistRowT parent, const GHistRowT left,
const GHistRowT right, size_t begin, size_t end) {
const GradientSumT* p_parent = reinterpret_cast<const GradientSumT*>(parent.data());
const GradientSumT* p_left = reinterpret_cast<const GradientSumT*>(left.data());
const GradientSumT* p_right = reinterpret_cast<const GradientSumT*>(right.data());
for (size_t i = 2 * begin; i < 2 * end; ++i) {
ASSERT_EQ(p_parent[i], p_left[i] + p_right[i]);
}
};
for (const ExpandEntry& node : nodes_for_explicit_hist_build_) {
auto this_hist = hist_[node.nid];
for (const ExpandEntryT& node : this->nodes_for_explicit_hist_build_) {
auto this_hist = this->hist_[node.nid];
const size_t parent_id = (*tree)[node.nid].Parent();
auto parent_hist = hist_[parent_id];
auto sibling_hist = hist_[node.sibling_nid];
auto parent_hist = this->hist_[parent_id];
auto sibling_hist = this->hist_[node.sibling_nid];
check_hist(parent_hist, this_hist, sibling_hist, 0, nbins);
}
for (const ExpandEntry& node : nodes_for_subtraction_trick_) {
auto this_hist = hist_[node.nid];
for (const ExpandEntryT& node : this->nodes_for_subtraction_trick_) {
auto this_hist = this->hist_[node.nid];
const size_t parent_id = (*tree)[node.nid].Parent();
auto parent_hist = hist_[parent_id];
auto sibling_hist = hist_[node.sibling_nid];
auto parent_hist = this->hist_[parent_id];
auto sibling_hist = this->hist_[node.sibling_nid];
check_hist(parent_hist, this_hist, sibling_hist, 0, nbins);
}
@@ -272,13 +274,13 @@ class QuantileHistMock : public QuantileHistMaker {
{0.27f, 0.29f}, {0.37f, 0.39f}, {0.47f, 0.49f}, {0.57f, 0.59f} };
RealImpl::InitData(gmat, gpair, fmat, tree);
GHistIndexBlockMatrix dummy;
hist_.AddHistRow(nid);
BuildHist(gpair, row_set_collection_[nid],
gmat, dummy, hist_[nid]);
this->hist_.AddHistRow(nid);
this->BuildHist(gpair, this->row_set_collection_[nid],
gmat, dummy, this->hist_[nid]);
// Check if number of histogram bins is correct
ASSERT_EQ(hist_[nid].size(), gmat.cut.Ptrs().back());
std::vector<GradientPairPrecise> histogram_expected(hist_[nid].size());
ASSERT_EQ(this->hist_[nid].size(), gmat.cut.Ptrs().back());
std::vector<GradientPairPrecise> histogram_expected(this->hist_[nid].size());
// Compute the correct histogram (histogram_expected)
const size_t num_row = fmat.Info().num_row_;
@@ -293,10 +295,10 @@ class QuantileHistMock : public QuantileHistMaker {
}
// Now validate the computed histogram returned by BuildHist
for (size_t i = 0; i < hist_[nid].size(); ++i) {
for (size_t i = 0; i < this->hist_[nid].size(); ++i) {
GradientPairPrecise sol = histogram_expected[i];
ASSERT_NEAR(sol.GetGrad(), hist_[nid][i].GetGrad(), kEps);
ASSERT_NEAR(sol.GetHess(), hist_[nid][i].GetHess(), kEps);
ASSERT_NEAR(sol.GetGrad(), this->hist_[nid][i].GetGrad(), kEps);
ASSERT_NEAR(sol.GetHess(), this->hist_[nid][i].GetHess(), kEps);
}
}
@@ -313,10 +315,10 @@ class QuantileHistMock : public QuantileHistMaker {
gmat.Init(dmat.get(), kMaxBins);
RealImpl::InitData(gmat, row_gpairs, *dmat, tree);
hist_.AddHistRow(0);
this->hist_.AddHistRow(0);
BuildHist(row_gpairs, row_set_collection_[0],
gmat, quantile_index_block, hist_[0]);
this->BuildHist(row_gpairs, this->row_set_collection_[0],
gmat, quantile_index_block, this->hist_[0]);
RealImpl::InitNewNode(0, gmat, row_gpairs, *dmat, tree);
@@ -331,7 +333,7 @@ class QuantileHistMock : public QuantileHistMaker {
}
// Initialize split evaluator
std::unique_ptr<SplitEvaluator> evaluator(SplitEvaluator::Create("elastic_net"));
evaluator->Init(&param_);
evaluator->Init(&this->param_);
// Now enumerate all feature*threshold combination to get best split
// To simplify logic, we make some assumptions:
@@ -378,11 +380,13 @@ class QuantileHistMock : public QuantileHistMaker {
}
/* Now compare against result given by EvaluateSplit() */
ExpandEntry node(ExpandEntry::kRootNid, ExpandEntry::kEmptyNid,
tree.GetDepth(0), snode_[0].best.loss_chg, 0);
RealImpl::EvaluateSplits({node}, gmat, hist_, tree);
ASSERT_EQ(snode_[0].best.SplitIndex(), best_split_feature);
ASSERT_EQ(snode_[0].best.split_value, gmat.cut.Values()[best_split_threshold]);
typename RealImpl::ExpandEntry node(RealImpl::ExpandEntry::kRootNid,
RealImpl::ExpandEntry::kEmptyNid,
tree.GetDepth(0),
this->snode_[0].best.loss_chg, 0);
RealImpl::EvaluateSplits({node}, gmat, this->hist_, tree);
ASSERT_EQ(this->snode_[0].best.SplitIndex(), best_split_feature);
ASSERT_EQ(this->snode_[0].best.split_value, gmat.cut.Values()[best_split_threshold]);
}
void TestEvaluateSplitParallel(const GHistIndexBlockMatrix &quantile_index_block,
@@ -411,7 +415,7 @@ class QuantileHistMock : public QuantileHistMaker {
// treat everything as dense, as this is what we intend to test here
cm.Init(gmat, 0.0);
RealImpl::InitData(gmat, row_gpairs, *dmat, tree);
hist_.AddHistRow(0);
this->hist_.AddHistRow(0);
RealImpl::InitNewNode(0, gmat, row_gpairs, *dmat, tree);
@@ -430,9 +434,9 @@ class QuantileHistMock : public QuantileHistMaker {
const size_t bin_id = gmat.index[offset];
if (bin_id >= bin_id_min && bin_id < bin_id_max) {
if (bin_id <= split) {
left_cnt ++;
left_cnt++;
} else {
right_cnt ++;
right_cnt++;
}
}
}
@@ -450,7 +454,8 @@ class QuantileHistMock : public QuantileHistMaker {
RealImpl::partition_builder_.Init(1, 1, [&](size_t node_in_set) {
return 1;
});
RealImpl::PartitionKernel<uint8_t>(0, 0, common::Range1d(0, kNRows), split, cm, tree);
this->template PartitionKernel<uint8_t>(0, 0, common::Range1d(0, kNRows),
split, cm, tree);
RealImpl::partition_builder_.CalculateRowOffsets();
ASSERT_EQ(RealImpl::partition_builder_.GetNLeftElems(0), left_cnt);
ASSERT_EQ(RealImpl::partition_builder_.GetNRightElems(0), right_cnt);
@@ -462,28 +467,47 @@ class QuantileHistMock : public QuantileHistMaker {
int static constexpr kNRows = 8, kNCols = 16;
std::shared_ptr<xgboost::DMatrix> dmat_;
const std::vector<std::pair<std::string, std::string> > cfg_;
std::shared_ptr<BuilderMock> builder_;
std::shared_ptr<BuilderMock<float> > float_builder_;
std::shared_ptr<BuilderMock<double> > double_builder_;
public:
explicit QuantileHistMock(
const std::vector<std::pair<std::string, std::string> >& args, bool batch = true) :
const std::vector<std::pair<std::string, std::string> >& args,
const bool single_precision_histogram = false, bool batch = true) :
cfg_{args} {
QuantileHistMaker::Configure(args);
spliteval_->Init(&param_);
dmat_ = RandomDataGenerator(kNRows, kNCols, 0.8).Seed(3).GenerateDMatrix();
builder_.reset(
new BuilderMock(
param_,
std::move(pruner_),
std::unique_ptr<SplitEvaluator>(spliteval_->GetHostClone()),
int_constraint_,
dmat_.get()));
if (batch) {
builder_->SetHistSynchronizer(new BatchHistSynchronizer());
builder_->SetHistRowsAdder(new BatchHistRowsAdder());
if (single_precision_histogram) {
float_builder_.reset(
new BuilderMock<float>(
param_,
std::move(pruner_),
std::unique_ptr<SplitEvaluator>(spliteval_->GetHostClone()),
int_constraint_,
dmat_.get()));
if (batch) {
float_builder_->SetHistSynchronizer(new BatchHistSynchronizer<float>());
float_builder_->SetHistRowsAdder(new BatchHistRowsAdder<float>());
} else {
float_builder_->SetHistSynchronizer(new DistributedHistSynchronizer<float>());
float_builder_->SetHistRowsAdder(new DistributedHistRowsAdder<float>());
}
} else {
builder_->SetHistSynchronizer(new DistributedHistSynchronizer());
builder_->SetHistRowsAdder(new DistributedHistRowsAdder());
double_builder_.reset(
new BuilderMock<double>(
param_,
std::move(pruner_),
std::unique_ptr<SplitEvaluator>(spliteval_->GetHostClone()),
int_constraint_,
dmat_.get()));
if (batch) {
double_builder_->SetHistSynchronizer(new BatchHistSynchronizer<double>());
double_builder_->SetHistRowsAdder(new BatchHistRowsAdder<double>());
} else {
double_builder_->SetHistSynchronizer(new DistributedHistSynchronizer<double>());
double_builder_->SetHistRowsAdder(new DistributedHistRowsAdder<double>());
}
}
}
~QuantileHistMock() override = default;
@@ -501,8 +525,11 @@ class QuantileHistMock : public QuantileHistMaker {
std::vector<GradientPair> gpair =
{ {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f},
{0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f} };
builder_->TestInitData(gmat, gpair, dmat_.get(), tree);
if (double_builder_) {
double_builder_->TestInitData(gmat, gpair, dmat_.get(), tree);
} else {
float_builder_->TestInitData(gmat, gpair, dmat_.get(), tree);
}
}
void TestInitDataSampling() {
@@ -516,8 +543,11 @@ class QuantileHistMock : public QuantileHistMaker {
std::vector<GradientPair> gpair =
{ {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f},
{0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f} };
builder_->TestInitDataSampling(gmat, gpair, dmat_.get(), tree);
if (double_builder_) {
double_builder_->TestInitDataSampling(gmat, gpair, dmat_.get(), tree);
} else {
float_builder_->TestInitDataSampling(gmat, gpair, dmat_.get(), tree);
}
}
void TestAddHistRows() {
@@ -530,7 +560,11 @@ class QuantileHistMock : public QuantileHistMaker {
std::vector<GradientPair> gpair =
{ {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f},
{0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f} };
builder_->TestAddHistRows(gmat, gpair, dmat_.get(), &tree);
if (double_builder_) {
double_builder_->TestAddHistRows(gmat, gpair, dmat_.get(), &tree);
} else {
float_builder_->TestAddHistRows(gmat, gpair, dmat_.get(), &tree);
}
}
void TestSyncHistograms() {
@@ -543,7 +577,11 @@ class QuantileHistMock : public QuantileHistMaker {
std::vector<GradientPair> gpair =
{ {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f},
{0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f} };
builder_->TestSyncHistograms(gmat, gpair, dmat_.get(), &tree);
if (double_builder_) {
double_builder_->TestSyncHistograms(gmat, gpair, dmat_.get(), &tree);
} else {
float_builder_->TestSyncHistograms(gmat, gpair, dmat_.get(), &tree);
}
}
@@ -554,22 +592,31 @@ class QuantileHistMock : public QuantileHistMaker {
size_t constexpr kMaxBins = 4;
common::GHistIndexMatrix gmat;
gmat.Init(dmat_.get(), kMaxBins);
builder_->TestBuildHist(0, gmat, *dmat_, tree);
if (double_builder_) {
double_builder_->TestBuildHist(0, gmat, *dmat_, tree);
} else {
float_builder_->TestBuildHist(0, gmat, *dmat_, tree);
}
}
void TestEvaluateSplit() {
RegTree tree = RegTree();
tree.param.UpdateAllowUnknown(cfg_);
builder_->TestEvaluateSplit(gmatb_, tree);
if (double_builder_) {
double_builder_->TestEvaluateSplit(gmatb_, tree);
} else {
float_builder_->TestEvaluateSplit(gmatb_, tree);
}
}
void TestApplySplit() {
RegTree tree = RegTree();
tree.param.UpdateAllowUnknown(cfg_);
builder_->TestApplySplit(gmatb_, tree);
if (double_builder_) {
double_builder_->TestApplySplit(gmatb_, tree);
} else {
float_builder_->TestEvaluateSplit(gmatb_, tree);
}
}
};
@@ -578,6 +625,9 @@ TEST(QuantileHist, InitData) {
{{"num_feature", std::to_string(QuantileHistMock::GetNumColumns())}};
QuantileHistMock maker(cfg);
maker.TestInitData();
const bool single_precision_histogram = true;
QuantileHistMock maker_float(cfg, single_precision_histogram);
maker_float.TestInitData();
}
TEST(QuantileHist, InitDataSampling) {
@@ -587,6 +637,9 @@ TEST(QuantileHist, InitDataSampling) {
{"subsample", std::to_string(subsample)}};
QuantileHistMock maker(cfg);
maker.TestInitDataSampling();
const bool single_precision_histogram = true;
QuantileHistMock maker_float(cfg, single_precision_histogram);
maker_float.TestInitDataSampling();
}
TEST(QuantileHist, AddHistRows) {
@@ -594,6 +647,9 @@ TEST(QuantileHist, AddHistRows) {
{{"num_feature", std::to_string(QuantileHistMock::GetNumColumns())}};
QuantileHistMock maker(cfg);
maker.TestAddHistRows();
const bool single_precision_histogram = true;
QuantileHistMock maker_float(cfg, single_precision_histogram);
maker_float.TestAddHistRows();
}
TEST(QuantileHist, SyncHistograms) {
@@ -601,6 +657,9 @@ TEST(QuantileHist, SyncHistograms) {
{{"num_feature", std::to_string(QuantileHistMock::GetNumColumns())}};
QuantileHistMock maker(cfg);
maker.TestSyncHistograms();
const bool single_precision_histogram = true;
QuantileHistMock maker_float(cfg, single_precision_histogram);
maker_float.TestSyncHistograms();
}
TEST(QuantileHist, DistributedAddHistRows) {
@@ -608,6 +667,9 @@ TEST(QuantileHist, DistributedAddHistRows) {
{{"num_feature", std::to_string(QuantileHistMock::GetNumColumns())}};
QuantileHistMock maker(cfg, false);
maker.TestAddHistRows();
const bool single_precision_histogram = true;
QuantileHistMock maker_float(cfg, single_precision_histogram);
maker_float.TestAddHistRows();
}
TEST(QuantileHist, DistributedSyncHistograms) {
@@ -615,6 +677,9 @@ TEST(QuantileHist, DistributedSyncHistograms) {
{{"num_feature", std::to_string(QuantileHistMock::GetNumColumns())}};
QuantileHistMock maker(cfg, false);
maker.TestSyncHistograms();
const bool single_precision_histogram = true;
QuantileHistMock maker_float(cfg, single_precision_histogram);
maker_float.TestSyncHistograms();
}
TEST(QuantileHist, BuildHist) {
@@ -624,6 +689,9 @@ TEST(QuantileHist, BuildHist) {
{"enable_feature_grouping", std::to_string(0)}};
QuantileHistMock maker(cfg);
maker.TestBuildHist();
const bool single_precision_histogram = true;
QuantileHistMock maker_float(cfg, single_precision_histogram);
maker_float.TestBuildHist();
}
TEST(QuantileHist, EvalSplits) {
@@ -634,6 +702,9 @@ TEST(QuantileHist, EvalSplits) {
{"min_child_weight", "0"}};
QuantileHistMock maker(cfg);
maker.TestEvaluateSplit();
const bool single_precision_histogram = true;
QuantileHistMock maker_float(cfg, single_precision_histogram);
maker_float.TestEvaluateSplit();
}
TEST(QuantileHist, ApplySplit) {
@@ -644,6 +715,9 @@ TEST(QuantileHist, ApplySplit) {
{"min_child_weight", "0"}};
QuantileHistMock maker(cfg);
maker.TestApplySplit();
const bool single_precision_histogram = true;
QuantileHistMock maker_float(cfg, single_precision_histogram);
maker_float.TestApplySplit();
}
} // namespace tree