Distributed optimizations for 'hist' method with CPUs (#5557)
Co-authored-by: SHVETS, KIRILL <kirill.shvets@intel.com>
This commit is contained in:
parent
e21a608552
commit
dd01e4ba8d
@ -14,7 +14,7 @@
|
|||||||
#include "xgboost/json.h"
|
#include "xgboost/json.h"
|
||||||
#include "./param.h"
|
#include "./param.h"
|
||||||
#include "../common/io.h"
|
#include "../common/io.h"
|
||||||
|
#include "../common/timer.h"
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace tree {
|
namespace tree {
|
||||||
|
|
||||||
@ -25,6 +25,7 @@ class TreePruner: public TreeUpdater {
|
|||||||
public:
|
public:
|
||||||
TreePruner() {
|
TreePruner() {
|
||||||
syncher_.reset(TreeUpdater::Create("sync", tparam_));
|
syncher_.reset(TreeUpdater::Create("sync", tparam_));
|
||||||
|
pruner_monitor_.Init("TreePruner");
|
||||||
}
|
}
|
||||||
char const* Name() const override {
|
char const* Name() const override {
|
||||||
return "prune";
|
return "prune";
|
||||||
@ -52,6 +53,7 @@ class TreePruner: public TreeUpdater {
|
|||||||
void Update(HostDeviceVector<GradientPair> *gpair,
|
void Update(HostDeviceVector<GradientPair> *gpair,
|
||||||
DMatrix *p_fmat,
|
DMatrix *p_fmat,
|
||||||
const std::vector<RegTree*> &trees) override {
|
const std::vector<RegTree*> &trees) override {
|
||||||
|
pruner_monitor_.Start("PrunerUpdate");
|
||||||
// rescale learning rate according to size of trees
|
// rescale learning rate according to size of trees
|
||||||
float lr = param_.learning_rate;
|
float lr = param_.learning_rate;
|
||||||
param_.learning_rate = lr / trees.size();
|
param_.learning_rate = lr / trees.size();
|
||||||
@ -60,6 +62,7 @@ class TreePruner: public TreeUpdater {
|
|||||||
}
|
}
|
||||||
param_.learning_rate = lr;
|
param_.learning_rate = lr;
|
||||||
syncher_->Update(gpair, p_fmat, trees);
|
syncher_->Update(gpair, p_fmat, trees);
|
||||||
|
pruner_monitor_.Stop("PrunerUpdate");
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
@ -105,6 +108,7 @@ class TreePruner: public TreeUpdater {
|
|||||||
std::unique_ptr<TreeUpdater> syncher_;
|
std::unique_ptr<TreeUpdater> syncher_;
|
||||||
// training parameter
|
// training parameter
|
||||||
TrainParam param_;
|
TrainParam param_;
|
||||||
|
common::Monitor pruner_monitor_;
|
||||||
};
|
};
|
||||||
|
|
||||||
XGBOOST_REGISTER_TREE_UPDATER(TreePruner, "prune")
|
XGBOOST_REGISTER_TREE_UPDATER(TreePruner, "prune")
|
||||||
|
|||||||
@ -55,12 +55,13 @@ void QuantileHistMaker::Update(HostDeviceVector<GradientPair> *gpair,
|
|||||||
DMatrix *dmat,
|
DMatrix *dmat,
|
||||||
const std::vector<RegTree *> &trees) {
|
const std::vector<RegTree *> &trees) {
|
||||||
if (dmat != p_last_dmat_ || is_gmat_initialized_ == false) {
|
if (dmat != p_last_dmat_ || is_gmat_initialized_ == false) {
|
||||||
|
updater_monitor_.Start("GmatInitialization");
|
||||||
gmat_.Init(dmat, static_cast<uint32_t>(param_.max_bin));
|
gmat_.Init(dmat, static_cast<uint32_t>(param_.max_bin));
|
||||||
column_matrix_.Init(gmat_, param_.sparse_threshold);
|
column_matrix_.Init(gmat_, param_.sparse_threshold);
|
||||||
|
|
||||||
if (param_.enable_feature_grouping > 0) {
|
if (param_.enable_feature_grouping > 0) {
|
||||||
gmatb_.Init(gmat_, column_matrix_, param_);
|
gmatb_.Init(gmat_, column_matrix_, param_);
|
||||||
}
|
}
|
||||||
|
updater_monitor_.Stop("GmatInitialization");
|
||||||
// A proper solution is puting cut matrix in DMatrix, see:
|
// A proper solution is puting cut matrix in DMatrix, see:
|
||||||
// https://github.com/dmlc/xgboost/issues/5143
|
// https://github.com/dmlc/xgboost/issues/5143
|
||||||
is_gmat_initialized_ = true;
|
is_gmat_initialized_ = true;
|
||||||
@ -76,10 +77,18 @@ void QuantileHistMaker::Update(HostDeviceVector<GradientPair> *gpair,
|
|||||||
std::move(pruner_),
|
std::move(pruner_),
|
||||||
std::unique_ptr<SplitEvaluator>(spliteval_->GetHostClone()),
|
std::unique_ptr<SplitEvaluator>(spliteval_->GetHostClone()),
|
||||||
int_constraint_, dmat));
|
int_constraint_, dmat));
|
||||||
|
if (rabit::IsDistributed()) {
|
||||||
|
builder_->SetHistSynchronizer(new DistributedHistSynchronizer());
|
||||||
|
builder_->SetHistRowsAdder(new DistributedHistRowsAdder());
|
||||||
|
} else {
|
||||||
|
builder_->SetHistSynchronizer(new BatchHistSynchronizer());
|
||||||
|
builder_->SetHistRowsAdder(new BatchHistRowsAdder());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
for (auto tree : trees) {
|
for (auto tree : trees) {
|
||||||
builder_->Update(gmat_, gmatb_, column_matrix_, gpair, dmat, tree);
|
builder_->Update(gmat_, gmatb_, column_matrix_, gpair, dmat, tree);
|
||||||
}
|
}
|
||||||
|
|
||||||
param_.learning_rate = lr;
|
param_.learning_rate = lr;
|
||||||
|
|
||||||
p_last_dmat_ = dmat;
|
p_last_dmat_ = dmat;
|
||||||
@ -95,43 +104,151 @@ bool QuantileHistMaker::UpdatePredictionCache(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void QuantileHistMaker::Builder::SyncHistograms(
|
void BatchHistSynchronizer::SyncHistograms(QuantileHistMaker::Builder* builder,
|
||||||
int starting_index,
|
int starting_index,
|
||||||
int sync_count,
|
int sync_count,
|
||||||
RegTree *p_tree) {
|
RegTree *p_tree) {
|
||||||
builder_monitor_.Start("SyncHistograms");
|
builder->builder_monitor_.Start("SyncHistograms");
|
||||||
|
const size_t nbins = builder->hist_builder_.GetNumBins();
|
||||||
const bool isDistributed = rabit::IsDistributed();
|
common::BlockedSpace2d space(builder->nodes_for_explicit_hist_build_.size(), [&](size_t node) {
|
||||||
|
|
||||||
const size_t nbins = hist_builder_.GetNumBins();
|
|
||||||
common::BlockedSpace2d space(nodes_for_explicit_hist_build_.size(), [&](size_t node) {
|
|
||||||
return nbins;
|
return nbins;
|
||||||
}, 1024);
|
}, 1024);
|
||||||
|
|
||||||
common::ParallelFor2d(space, this->nthread_, [&](size_t node, common::Range1d r) {
|
common::ParallelFor2d(space, builder->nthread_, [&](size_t node, common::Range1d r) {
|
||||||
const auto entry = nodes_for_explicit_hist_build_[node];
|
const auto entry = builder->nodes_for_explicit_hist_build_[node];
|
||||||
auto this_hist = hist_[entry.nid];
|
auto this_hist = builder->hist_[entry.nid];
|
||||||
// Merging histograms from each thread into once
|
// Merging histograms from each thread into once
|
||||||
hist_buffer_.ReduceHist(node, r.begin(), r.end());
|
builder->hist_buffer_.ReduceHist(node, r.begin(), r.end());
|
||||||
|
|
||||||
if (!(*p_tree)[entry.nid].IsRoot() && entry.sibling_nid > -1 && !isDistributed) {
|
|
||||||
auto parent_hist = hist_[(*p_tree)[entry.nid].Parent()];
|
|
||||||
auto sibling_hist = hist_[entry.sibling_nid];
|
|
||||||
|
|
||||||
|
if (!(*p_tree)[entry.nid].IsRoot() && entry.sibling_nid > -1) {
|
||||||
|
const size_t parent_id = (*p_tree)[entry.nid].Parent();
|
||||||
|
auto parent_hist = builder->hist_[parent_id];
|
||||||
|
auto sibling_hist = builder->hist_[entry.sibling_nid];
|
||||||
SubtractionHist(sibling_hist, parent_hist, this_hist, r.begin(), r.end());
|
SubtractionHist(sibling_hist, parent_hist, this_hist, r.begin(), r.end());
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
builder->builder_monitor_.Stop("SyncHistograms");
|
||||||
|
}
|
||||||
|
|
||||||
if (isDistributed) {
|
void DistributedHistSynchronizer::SyncHistograms(QuantileHistMaker::Builder* builder,
|
||||||
this->histred_.Allreduce(hist_[starting_index].data(), hist_builder_.GetNumBins() * sync_count);
|
int starting_index,
|
||||||
// use Subtraction Trick
|
int sync_count,
|
||||||
for (auto const& node : nodes_for_subtraction_trick_) {
|
RegTree *p_tree) {
|
||||||
SubtractionTrick(hist_[node.nid], hist_[node.sibling_nid],
|
builder->builder_monitor_.Start("SyncHistograms");
|
||||||
hist_[(*p_tree)[node.nid].Parent()]);
|
const size_t nbins = builder->hist_builder_.GetNumBins();
|
||||||
|
common::BlockedSpace2d space(builder->nodes_for_explicit_hist_build_.size(), [&](size_t node) {
|
||||||
|
return nbins;
|
||||||
|
}, 1024);
|
||||||
|
common::ParallelFor2d(space, builder->nthread_, [&](size_t node, common::Range1d r) {
|
||||||
|
const auto entry = builder->nodes_for_explicit_hist_build_[node];
|
||||||
|
auto this_hist = builder->hist_[entry.nid];
|
||||||
|
// Merging histograms from each thread into once
|
||||||
|
builder->hist_buffer_.ReduceHist(node, r.begin(), r.end());
|
||||||
|
// Store posible parent node
|
||||||
|
auto this_local = builder->hist_local_worker_[entry.nid];
|
||||||
|
CopyHist(this_local, this_hist, r.begin(), r.end());
|
||||||
|
|
||||||
|
if (!(*p_tree)[entry.nid].IsRoot() && entry.sibling_nid > -1) {
|
||||||
|
const size_t parent_id = (*p_tree)[entry.nid].Parent();
|
||||||
|
auto parent_hist = builder->hist_local_worker_[parent_id];
|
||||||
|
auto sibling_hist = builder->hist_[entry.sibling_nid];
|
||||||
|
SubtractionHist(sibling_hist, parent_hist, this_hist, r.begin(), r.end());
|
||||||
|
// Store posible parent node
|
||||||
|
auto sibling_local = builder->hist_local_worker_[entry.sibling_nid];
|
||||||
|
CopyHist(sibling_local, sibling_hist, r.begin(), r.end());
|
||||||
}
|
}
|
||||||
|
});
|
||||||
|
builder->builder_monitor_.Start("SyncHistogramsAllreduce");
|
||||||
|
builder->histred_.Allreduce(builder->hist_[starting_index].data(),
|
||||||
|
builder->hist_builder_.GetNumBins() * sync_count);
|
||||||
|
builder->builder_monitor_.Stop("SyncHistogramsAllreduce");
|
||||||
|
|
||||||
|
ParallelSubtractionHist(builder, space, builder->nodes_for_explicit_hist_build_, p_tree);
|
||||||
|
|
||||||
|
common::BlockedSpace2d space2(builder->nodes_for_subtraction_trick_.size(), [&](size_t node) {
|
||||||
|
return nbins;
|
||||||
|
}, 1024);
|
||||||
|
ParallelSubtractionHist(builder, space2, builder->nodes_for_subtraction_trick_, p_tree);
|
||||||
|
builder->builder_monitor_.Stop("SyncHistograms");
|
||||||
|
}
|
||||||
|
|
||||||
|
void DistributedHistSynchronizer::ParallelSubtractionHist(QuantileHistMaker::Builder* builder,
|
||||||
|
const common::BlockedSpace2d& space,
|
||||||
|
const std::vector<QuantileHistMaker::Builder::ExpandEntry>& nodes,
|
||||||
|
const RegTree * p_tree) {
|
||||||
|
common::ParallelFor2d(space, builder->nthread_, [&](size_t node, common::Range1d r) {
|
||||||
|
const auto entry = nodes[node];
|
||||||
|
if (!((*p_tree)[entry.nid].IsLeftChild())) {
|
||||||
|
auto this_hist = builder->hist_[entry.nid];
|
||||||
|
|
||||||
|
if (!(*p_tree)[entry.nid].IsRoot() && entry.sibling_nid > -1) {
|
||||||
|
auto parent_hist = builder->hist_[(*p_tree)[entry.nid].Parent()];
|
||||||
|
auto sibling_hist = builder->hist_[entry.sibling_nid];
|
||||||
|
SubtractionHist(this_hist, parent_hist, sibling_hist, r.begin(), r.end());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
void BatchHistRowsAdder::AddHistRows(QuantileHistMaker::Builder* builder,
|
||||||
|
int *starting_index, int *sync_count,
|
||||||
|
RegTree *p_tree) {
|
||||||
|
builder->builder_monitor_.Start("AddHistRows");
|
||||||
|
|
||||||
|
for (auto const& entry : builder->nodes_for_explicit_hist_build_) {
|
||||||
|
int nid = entry.nid;
|
||||||
|
builder->hist_.AddHistRow(nid);
|
||||||
|
(*starting_index) = std::min(nid, (*starting_index));
|
||||||
|
}
|
||||||
|
(*sync_count) = builder->nodes_for_explicit_hist_build_.size();
|
||||||
|
|
||||||
|
for (auto const& node : builder->nodes_for_subtraction_trick_) {
|
||||||
|
builder->hist_.AddHistRow(node.nid);
|
||||||
}
|
}
|
||||||
|
|
||||||
builder_monitor_.Stop("SyncHistograms");
|
builder->builder_monitor_.Stop("AddHistRows");
|
||||||
|
}
|
||||||
|
|
||||||
|
void DistributedHistRowsAdder::AddHistRows(QuantileHistMaker::Builder* builder,
|
||||||
|
int *starting_index, int *sync_count,
|
||||||
|
RegTree *p_tree) {
|
||||||
|
builder->builder_monitor_.Start("AddHistRows");
|
||||||
|
const size_t explicit_size = builder->nodes_for_explicit_hist_build_.size();
|
||||||
|
const size_t subtaction_size = builder->nodes_for_subtraction_trick_.size();
|
||||||
|
std::vector<int> merged_node_ids(explicit_size + subtaction_size);
|
||||||
|
for (size_t i = 0; i < explicit_size; ++i) {
|
||||||
|
merged_node_ids[i] = builder->nodes_for_explicit_hist_build_[i].nid;
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < subtaction_size; ++i) {
|
||||||
|
merged_node_ids[explicit_size + i] =
|
||||||
|
builder->nodes_for_subtraction_trick_[i].nid;
|
||||||
|
}
|
||||||
|
std::sort(merged_node_ids.begin(), merged_node_ids.end());
|
||||||
|
int n_left = 0;
|
||||||
|
for (auto const& nid : merged_node_ids) {
|
||||||
|
if ((*p_tree)[nid].IsLeftChild()) {
|
||||||
|
builder->hist_.AddHistRow(nid);
|
||||||
|
(*starting_index) = std::min(nid, (*starting_index));
|
||||||
|
n_left++;
|
||||||
|
builder->hist_local_worker_.AddHistRow(nid);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (auto const& nid : merged_node_ids) {
|
||||||
|
if (!((*p_tree)[nid].IsLeftChild())) {
|
||||||
|
builder->hist_.AddHistRow(nid);
|
||||||
|
builder->hist_local_worker_.AddHistRow(nid);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
(*sync_count) = std::max(1, n_left);
|
||||||
|
builder->builder_monitor_.Stop("AddHistRows");
|
||||||
|
}
|
||||||
|
|
||||||
|
void QuantileHistMaker::Builder::SetHistSynchronizer(HistSynchronizer* sync) {
|
||||||
|
hist_synchronizer_.reset(sync);
|
||||||
|
}
|
||||||
|
|
||||||
|
void QuantileHistMaker::Builder::SetHistRowsAdder(HistRowsAdder* adder) {
|
||||||
|
hist_rows_adder_.reset(adder);
|
||||||
}
|
}
|
||||||
|
|
||||||
void QuantileHistMaker::Builder::BuildHistogramsLossGuide(
|
void QuantileHistMaker::Builder::BuildHistogramsLossGuide(
|
||||||
@ -152,30 +269,11 @@ void QuantileHistMaker::Builder::BuildHistogramsLossGuide(
|
|||||||
int starting_index = std::numeric_limits<int>::max();
|
int starting_index = std::numeric_limits<int>::max();
|
||||||
int sync_count = 0;
|
int sync_count = 0;
|
||||||
|
|
||||||
AddHistRows(&starting_index, &sync_count);
|
hist_rows_adder_->AddHistRows(this, &starting_index, &sync_count, p_tree);
|
||||||
BuildLocalHistograms(gmat, gmatb, p_tree, gpair_h);
|
BuildLocalHistograms(gmat, gmatb, p_tree, gpair_h);
|
||||||
SyncHistograms(starting_index, sync_count, p_tree);
|
hist_synchronizer_->SyncHistograms(this, starting_index, sync_count, p_tree);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
void QuantileHistMaker::Builder::AddHistRows(int *starting_index, int *sync_count) {
|
|
||||||
builder_monitor_.Start("AddHistRows");
|
|
||||||
|
|
||||||
for (auto const& entry : nodes_for_explicit_hist_build_) {
|
|
||||||
int nid = entry.nid;
|
|
||||||
hist_.AddHistRow(nid);
|
|
||||||
(*starting_index) = std::min(nid, (*starting_index));
|
|
||||||
}
|
|
||||||
(*sync_count) = nodes_for_explicit_hist_build_.size();
|
|
||||||
|
|
||||||
for (auto const& node : nodes_for_subtraction_trick_) {
|
|
||||||
hist_.AddHistRow(node.nid);
|
|
||||||
}
|
|
||||||
|
|
||||||
builder_monitor_.Stop("AddHistRows");
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
void QuantileHistMaker::Builder::BuildLocalHistograms(
|
void QuantileHistMaker::Builder::BuildLocalHistograms(
|
||||||
const GHistIndexMatrix &gmat,
|
const GHistIndexMatrix &gmat,
|
||||||
const GHistIndexBlockMatrix &gmatb,
|
const GHistIndexBlockMatrix &gmatb,
|
||||||
@ -184,6 +282,7 @@ void QuantileHistMaker::Builder::BuildLocalHistograms(
|
|||||||
builder_monitor_.Start("BuildLocalHistograms");
|
builder_monitor_.Start("BuildLocalHistograms");
|
||||||
|
|
||||||
const size_t n_nodes = nodes_for_explicit_hist_build_.size();
|
const size_t n_nodes = nodes_for_explicit_hist_build_.size();
|
||||||
|
|
||||||
// create space of size (# rows in each node)
|
// create space of size (# rows in each node)
|
||||||
common::BlockedSpace2d space(n_nodes, [&](size_t node) {
|
common::BlockedSpace2d space(n_nodes, [&](size_t node) {
|
||||||
const int32_t nid = nodes_for_explicit_hist_build_[node].nid;
|
const int32_t nid = nodes_for_explicit_hist_build_[node].nid;
|
||||||
@ -305,31 +404,28 @@ void QuantileHistMaker::Builder::SplitSiblings(const std::vector<ExpandEntry>& n
|
|||||||
std::vector<ExpandEntry>* small_siblings,
|
std::vector<ExpandEntry>* small_siblings,
|
||||||
std::vector<ExpandEntry>* big_siblings,
|
std::vector<ExpandEntry>* big_siblings,
|
||||||
RegTree *p_tree) {
|
RegTree *p_tree) {
|
||||||
|
builder_monitor_.Start("SplitSiblings");
|
||||||
for (auto const& entry : nodes) {
|
for (auto const& entry : nodes) {
|
||||||
int nid = entry.nid;
|
int nid = entry.nid;
|
||||||
RegTree::Node &node = (*p_tree)[nid];
|
RegTree::Node &node = (*p_tree)[nid];
|
||||||
if (rabit::IsDistributed()) {
|
if (node.IsRoot()) {
|
||||||
if (node.IsRoot() || node.IsLeftChild()) {
|
small_siblings->push_back(entry);
|
||||||
small_siblings->push_back(entry);
|
|
||||||
} else {
|
|
||||||
big_siblings->push_back(entry);
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
if (!node.IsRoot() && node.IsLeftChild() &&
|
const int32_t left_id = (*p_tree)[node.Parent()].LeftChild();
|
||||||
(row_set_collection_[nid].Size() <
|
const int32_t right_id = (*p_tree)[node.Parent()].RightChild();
|
||||||
row_set_collection_[(*p_tree)[node.Parent()].RightChild()].Size())) {
|
|
||||||
|
if (nid == left_id && row_set_collection_[left_id ].Size() <
|
||||||
|
row_set_collection_[right_id].Size()) {
|
||||||
small_siblings->push_back(entry);
|
small_siblings->push_back(entry);
|
||||||
} else if (!node.IsRoot() && !node.IsLeftChild() &&
|
} else if (nid == right_id && row_set_collection_[right_id].Size() <=
|
||||||
(row_set_collection_[nid].Size() <=
|
row_set_collection_[left_id ].Size()) {
|
||||||
row_set_collection_[(*p_tree)[node.Parent()].LeftChild()].Size())) {
|
|
||||||
small_siblings->push_back(entry);
|
|
||||||
} else if (node.IsRoot()) {
|
|
||||||
small_siblings->push_back(entry);
|
small_siblings->push_back(entry);
|
||||||
} else {
|
} else {
|
||||||
big_siblings->push_back(entry);
|
big_siblings->push_back(entry);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
builder_monitor_.Stop("SplitSiblings");
|
||||||
}
|
}
|
||||||
|
|
||||||
void QuantileHistMaker::Builder::ExpandWithDepthWise(
|
void QuantileHistMaker::Builder::ExpandWithDepthWise(
|
||||||
@ -350,17 +446,16 @@ void QuantileHistMaker::Builder::ExpandWithDepthWise(
|
|||||||
int starting_index = std::numeric_limits<int>::max();
|
int starting_index = std::numeric_limits<int>::max();
|
||||||
int sync_count = 0;
|
int sync_count = 0;
|
||||||
std::vector<ExpandEntry> temp_qexpand_depth;
|
std::vector<ExpandEntry> temp_qexpand_depth;
|
||||||
|
|
||||||
SplitSiblings(qexpand_depth_wise_, &nodes_for_explicit_hist_build_,
|
SplitSiblings(qexpand_depth_wise_, &nodes_for_explicit_hist_build_,
|
||||||
&nodes_for_subtraction_trick_, p_tree);
|
&nodes_for_subtraction_trick_, p_tree);
|
||||||
AddHistRows(&starting_index, &sync_count);
|
hist_rows_adder_->AddHistRows(this, &starting_index, &sync_count, p_tree);
|
||||||
|
|
||||||
BuildLocalHistograms(gmat, gmatb, p_tree, gpair_h);
|
BuildLocalHistograms(gmat, gmatb, p_tree, gpair_h);
|
||||||
SyncHistograms(starting_index, sync_count, p_tree);
|
hist_synchronizer_->SyncHistograms(this, starting_index, sync_count, p_tree);
|
||||||
|
|
||||||
BuildNodeStats(gmat, p_fmat, p_tree, gpair_h);
|
BuildNodeStats(gmat, p_fmat, p_tree, gpair_h);
|
||||||
|
|
||||||
EvaluateAndApplySplits(gmat, column_matrix, p_tree, &num_leaves, depth, ×tamp,
|
EvaluateAndApplySplits(gmat, column_matrix, p_tree, &num_leaves, depth, ×tamp,
|
||||||
&temp_qexpand_depth);
|
&temp_qexpand_depth);
|
||||||
|
|
||||||
// clean up
|
// clean up
|
||||||
qexpand_depth_wise_.clear();
|
qexpand_depth_wise_.clear();
|
||||||
nodes_for_subtraction_trick_.clear();
|
nodes_for_subtraction_trick_.clear();
|
||||||
@ -381,7 +476,7 @@ void QuantileHistMaker::Builder::ExpandWithLossGuide(
|
|||||||
DMatrix* p_fmat,
|
DMatrix* p_fmat,
|
||||||
RegTree* p_tree,
|
RegTree* p_tree,
|
||||||
const std::vector<GradientPair>& gpair_h) {
|
const std::vector<GradientPair>& gpair_h) {
|
||||||
|
builder_monitor_.Start("ExpandWithLossGuide");
|
||||||
unsigned timestamp = 0;
|
unsigned timestamp = 0;
|
||||||
int num_leaves = 0;
|
int num_leaves = 0;
|
||||||
|
|
||||||
@ -424,15 +519,10 @@ void QuantileHistMaker::Builder::ExpandWithLossGuide(
|
|||||||
ExpandEntry right_node(cright, cleft, p_tree->GetDepth(cright),
|
ExpandEntry right_node(cright, cleft, p_tree->GetDepth(cright),
|
||||||
0.0f, timestamp++);
|
0.0f, timestamp++);
|
||||||
|
|
||||||
if (rabit::IsDistributed()) {
|
if (row_set_collection_[cleft].Size() < row_set_collection_[cright].Size()) {
|
||||||
// in distributed mode, we need to keep consistent across workers
|
|
||||||
BuildHistogramsLossGuide(left_node, gmat, gmatb, p_tree, gpair_h);
|
BuildHistogramsLossGuide(left_node, gmat, gmatb, p_tree, gpair_h);
|
||||||
} else {
|
} else {
|
||||||
if (row_set_collection_[cleft].Size() < row_set_collection_[cright].Size()) {
|
BuildHistogramsLossGuide(right_node, gmat, gmatb, p_tree, gpair_h);
|
||||||
BuildHistogramsLossGuide(left_node, gmat, gmatb, p_tree, gpair_h);
|
|
||||||
} else {
|
|
||||||
BuildHistogramsLossGuide(right_node, gmat, gmatb, p_tree, gpair_h);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
this->InitNewNode(cleft, gmat, gpair_h, *p_fmat, *p_tree);
|
this->InitNewNode(cleft, gmat, gpair_h, *p_fmat, *p_tree);
|
||||||
@ -452,6 +542,7 @@ void QuantileHistMaker::Builder::ExpandWithLossGuide(
|
|||||||
++num_leaves; // give two and take one, as parent is no longer a leaf
|
++num_leaves; // give two and take one, as parent is no longer a leaf
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
builder_monitor_.Stop("ExpandWithLossGuide");
|
||||||
}
|
}
|
||||||
|
|
||||||
void QuantileHistMaker::Builder::Update(const GHistIndexMatrix& gmat,
|
void QuantileHistMaker::Builder::Update(const GHistIndexMatrix& gmat,
|
||||||
@ -468,7 +559,6 @@ void QuantileHistMaker::Builder::Update(const GHistIndexMatrix& gmat,
|
|||||||
interaction_constraints_.Reset();
|
interaction_constraints_.Reset();
|
||||||
|
|
||||||
this->InitData(gmat, gpair_h, *p_fmat, *p_tree);
|
this->InitData(gmat, gpair_h, *p_fmat, *p_tree);
|
||||||
|
|
||||||
if (param_.grow_policy == TrainParam::kLossGuide) {
|
if (param_.grow_policy == TrainParam::kLossGuide) {
|
||||||
ExpandWithLossGuide(gmat, gmatb, column_matrix, p_fmat, p_tree, gpair_h);
|
ExpandWithLossGuide(gmat, gmatb, column_matrix, p_fmat, p_tree, gpair_h);
|
||||||
} else {
|
} else {
|
||||||
@ -480,7 +570,6 @@ void QuantileHistMaker::Builder::Update(const GHistIndexMatrix& gmat,
|
|||||||
p_tree->Stat(nid).base_weight = snode_[nid].weight;
|
p_tree->Stat(nid).base_weight = snode_[nid].weight;
|
||||||
p_tree->Stat(nid).sum_hess = static_cast<float>(snode_[nid].stats.sum_hess);
|
p_tree->Stat(nid).sum_hess = static_cast<float>(snode_[nid].stats.sum_hess);
|
||||||
}
|
}
|
||||||
|
|
||||||
pruner_->Update(gpair, p_fmat, std::vector<RegTree*>{p_tree});
|
pruner_->Update(gpair, p_fmat, std::vector<RegTree*>{p_tree});
|
||||||
|
|
||||||
builder_monitor_.Stop("Update");
|
builder_monitor_.Stop("Update");
|
||||||
@ -615,6 +704,7 @@ void QuantileHistMaker::Builder::InitData(const GHistIndexMatrix& gmat,
|
|||||||
// initialize histogram collection
|
// initialize histogram collection
|
||||||
uint32_t nbins = gmat.cut.Ptrs().back();
|
uint32_t nbins = gmat.cut.Ptrs().back();
|
||||||
hist_.Init(nbins);
|
hist_.Init(nbins);
|
||||||
|
hist_local_worker_.Init(nbins);
|
||||||
hist_buffer_.Init(nbins);
|
hist_buffer_.Init(nbins);
|
||||||
|
|
||||||
// initialize histogram builder
|
// initialize histogram builder
|
||||||
@ -1026,18 +1116,15 @@ void QuantileHistMaker::Builder::ApplySplit(const std::vector<ExpandEntry> nodes
|
|||||||
const HistCollection& hist,
|
const HistCollection& hist,
|
||||||
RegTree* p_tree) {
|
RegTree* p_tree) {
|
||||||
builder_monitor_.Start("ApplySplit");
|
builder_monitor_.Start("ApplySplit");
|
||||||
|
|
||||||
// 1. Find split condition for each split
|
// 1. Find split condition for each split
|
||||||
const size_t n_nodes = nodes.size();
|
const size_t n_nodes = nodes.size();
|
||||||
std::vector<int32_t> split_conditions;
|
std::vector<int32_t> split_conditions;
|
||||||
FindSplitConditions(nodes, *p_tree, gmat, &split_conditions);
|
FindSplitConditions(nodes, *p_tree, gmat, &split_conditions);
|
||||||
|
|
||||||
// 2.1 Create a blocked space of size SUM(samples in each node)
|
// 2.1 Create a blocked space of size SUM(samples in each node)
|
||||||
common::BlockedSpace2d space(n_nodes, [&](size_t node_in_set) {
|
common::BlockedSpace2d space(n_nodes, [&](size_t node_in_set) {
|
||||||
int32_t nid = nodes[node_in_set].nid;
|
int32_t nid = nodes[node_in_set].nid;
|
||||||
return row_set_collection_[nid].Size();
|
return row_set_collection_[nid].Size();
|
||||||
}, kPartitionBlockSize);
|
}, kPartitionBlockSize);
|
||||||
|
|
||||||
// 2.2 Initialize the partition builder
|
// 2.2 Initialize the partition builder
|
||||||
// allocate buffers for storage intermediate results by each thread
|
// allocate buffers for storage intermediate results by each thread
|
||||||
partition_builder_.Init(space.Size(), n_nodes, [&](size_t node_in_set) {
|
partition_builder_.Init(space.Size(), n_nodes, [&](size_t node_in_set) {
|
||||||
@ -1046,7 +1133,6 @@ void QuantileHistMaker::Builder::ApplySplit(const std::vector<ExpandEntry> nodes
|
|||||||
const size_t n_tasks = size / kPartitionBlockSize + !!(size % kPartitionBlockSize);
|
const size_t n_tasks = size / kPartitionBlockSize + !!(size % kPartitionBlockSize);
|
||||||
return n_tasks;
|
return n_tasks;
|
||||||
});
|
});
|
||||||
|
|
||||||
// 2.3 Split elements of row_set_collection_ to left and right child-nodes for each node
|
// 2.3 Split elements of row_set_collection_ to left and right child-nodes for each node
|
||||||
// Store results in intermediate buffers from partition_builder_
|
// Store results in intermediate buffers from partition_builder_
|
||||||
common::ParallelFor2d(space, this->nthread_, [&](size_t node_in_set, common::Range1d r) {
|
common::ParallelFor2d(space, this->nthread_, [&](size_t node_in_set, common::Range1d r) {
|
||||||
@ -1068,7 +1154,6 @@ void QuantileHistMaker::Builder::ApplySplit(const std::vector<ExpandEntry> nodes
|
|||||||
CHECK(false); // no default behavior
|
CHECK(false); // no default behavior
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
// 3. Compute offsets to copy blocks of row-indexes
|
// 3. Compute offsets to copy blocks of row-indexes
|
||||||
// from partition_builder_ to row_set_collection_
|
// from partition_builder_ to row_set_collection_
|
||||||
partition_builder_.CalculateRowOffsets();
|
partition_builder_.CalculateRowOffsets();
|
||||||
@ -1080,10 +1165,8 @@ void QuantileHistMaker::Builder::ApplySplit(const std::vector<ExpandEntry> nodes
|
|||||||
partition_builder_.MergeToArray(node_in_set, r.begin(),
|
partition_builder_.MergeToArray(node_in_set, r.begin(),
|
||||||
const_cast<size_t*>(row_set_collection_[nid].begin));
|
const_cast<size_t*>(row_set_collection_[nid].begin));
|
||||||
});
|
});
|
||||||
|
|
||||||
// 5. Add info about splits into row_set_collection_
|
// 5. Add info about splits into row_set_collection_
|
||||||
AddSplitsToRowSet(nodes, p_tree);
|
AddSplitsToRowSet(nodes, p_tree);
|
||||||
|
|
||||||
builder_monitor_.Stop("ApplySplit");
|
builder_monitor_.Stop("ApplySplit");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -78,10 +78,14 @@ using xgboost::common::GHistBuilder;
|
|||||||
using xgboost::common::ColumnMatrix;
|
using xgboost::common::ColumnMatrix;
|
||||||
using xgboost::common::Column;
|
using xgboost::common::Column;
|
||||||
|
|
||||||
|
class HistSynchronizer;
|
||||||
|
class HistRowsAdder;
|
||||||
/*! \brief construct a tree using quantized feature values */
|
/*! \brief construct a tree using quantized feature values */
|
||||||
class QuantileHistMaker: public TreeUpdater {
|
class QuantileHistMaker: public TreeUpdater {
|
||||||
public:
|
public:
|
||||||
QuantileHistMaker() = default;
|
QuantileHistMaker() {
|
||||||
|
updater_monitor_.Init("QuantileHistMaker");
|
||||||
|
}
|
||||||
void Configure(const Args& args) override;
|
void Configure(const Args& args) override;
|
||||||
|
|
||||||
void Update(HostDeviceVector<GradientPair>* gpair,
|
void Update(HostDeviceVector<GradientPair>* gpair,
|
||||||
@ -105,6 +109,12 @@ class QuantileHistMaker: public TreeUpdater {
|
|||||||
}
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
|
friend class HistSynchronizer;
|
||||||
|
friend class BatchHistSynchronizer;
|
||||||
|
friend class DistributedHistSynchronizer;
|
||||||
|
friend class HistRowsAdder;
|
||||||
|
friend class BatchHistRowsAdder;
|
||||||
|
friend class DistributedHistRowsAdder;
|
||||||
// training parameter
|
// training parameter
|
||||||
TrainParam param_;
|
TrainParam param_;
|
||||||
// quantized data matrix
|
// quantized data matrix
|
||||||
@ -174,8 +184,16 @@ class QuantileHistMaker: public TreeUpdater {
|
|||||||
|
|
||||||
bool UpdatePredictionCache(const DMatrix* data,
|
bool UpdatePredictionCache(const DMatrix* data,
|
||||||
HostDeviceVector<bst_float>* p_out_preds);
|
HostDeviceVector<bst_float>* p_out_preds);
|
||||||
|
void SetHistSynchronizer(HistSynchronizer* sync);
|
||||||
|
void SetHistRowsAdder(HistRowsAdder* adder);
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
|
friend class HistSynchronizer;
|
||||||
|
friend class BatchHistSynchronizer;
|
||||||
|
friend class DistributedHistSynchronizer;
|
||||||
|
friend class HistRowsAdder;
|
||||||
|
friend class BatchHistRowsAdder;
|
||||||
|
friend class DistributedHistRowsAdder;
|
||||||
/* tree growing policies */
|
/* tree growing policies */
|
||||||
struct ExpandEntry {
|
struct ExpandEntry {
|
||||||
static const int kRootNid = 0;
|
static const int kRootNid = 0;
|
||||||
@ -259,8 +277,6 @@ class QuantileHistMaker: public TreeUpdater {
|
|||||||
RegTree *p_tree,
|
RegTree *p_tree,
|
||||||
const std::vector<GradientPair> &gpair_h);
|
const std::vector<GradientPair> &gpair_h);
|
||||||
|
|
||||||
void AddHistRows(int *starting_index, int *sync_count);
|
|
||||||
|
|
||||||
void BuildHistogramsLossGuide(
|
void BuildHistogramsLossGuide(
|
||||||
ExpandEntry entry,
|
ExpandEntry entry,
|
||||||
const GHistIndexMatrix &gmat,
|
const GHistIndexMatrix &gmat,
|
||||||
@ -276,9 +292,9 @@ class QuantileHistMaker: public TreeUpdater {
|
|||||||
std::vector<ExpandEntry>* big_siblings,
|
std::vector<ExpandEntry>* big_siblings,
|
||||||
RegTree *p_tree);
|
RegTree *p_tree);
|
||||||
|
|
||||||
void SyncHistograms(int starting_index,
|
void ParallelSubtractionHist(const common::BlockedSpace2d& space,
|
||||||
int sync_count,
|
const std::vector<ExpandEntry>& nodes,
|
||||||
RegTree *p_tree);
|
const RegTree * p_tree);
|
||||||
|
|
||||||
void BuildNodeStats(const GHistIndexMatrix &gmat,
|
void BuildNodeStats(const GHistIndexMatrix &gmat,
|
||||||
DMatrix *p_fmat,
|
DMatrix *p_fmat,
|
||||||
@ -316,7 +332,6 @@ class QuantileHistMaker: public TreeUpdater {
|
|||||||
return lhs.loss_chg < rhs.loss_chg; // favor large loss_chg
|
return lhs.loss_chg < rhs.loss_chg; // favor large loss_chg
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// --data fields--
|
// --data fields--
|
||||||
const TrainParam& param_;
|
const TrainParam& param_;
|
||||||
// number of omp thread used during training
|
// number of omp thread used during training
|
||||||
@ -331,6 +346,8 @@ class QuantileHistMaker: public TreeUpdater {
|
|||||||
std::vector<NodeEntry> snode_;
|
std::vector<NodeEntry> snode_;
|
||||||
/*! \brief culmulative histogram of gradients. */
|
/*! \brief culmulative histogram of gradients. */
|
||||||
HistCollection hist_;
|
HistCollection hist_;
|
||||||
|
/*! \brief culmulative local parent histogram of gradients. */
|
||||||
|
HistCollection hist_local_worker_;
|
||||||
/*! \brief feature with least # of bins. to be used for dense specialization
|
/*! \brief feature with least # of bins. to be used for dense specialization
|
||||||
of InitNewNode() */
|
of InitNewNode() */
|
||||||
uint32_t fid_least_bins_;
|
uint32_t fid_least_bins_;
|
||||||
@ -367,14 +384,62 @@ class QuantileHistMaker: public TreeUpdater {
|
|||||||
common::Monitor builder_monitor_;
|
common::Monitor builder_monitor_;
|
||||||
common::ParallelGHistBuilder hist_buffer_;
|
common::ParallelGHistBuilder hist_buffer_;
|
||||||
rabit::Reducer<GradStats, GradStats::Reduce> histred_;
|
rabit::Reducer<GradStats, GradStats::Reduce> histred_;
|
||||||
|
std::unique_ptr<HistSynchronizer> hist_synchronizer_;
|
||||||
|
std::unique_ptr<HistRowsAdder> hist_rows_adder_;
|
||||||
};
|
};
|
||||||
|
common::Monitor updater_monitor_;
|
||||||
std::unique_ptr<Builder> builder_;
|
std::unique_ptr<Builder> builder_;
|
||||||
std::unique_ptr<TreeUpdater> pruner_;
|
std::unique_ptr<TreeUpdater> pruner_;
|
||||||
std::unique_ptr<SplitEvaluator> spliteval_;
|
std::unique_ptr<SplitEvaluator> spliteval_;
|
||||||
FeatureInteractionConstraintHost int_constraint_;
|
FeatureInteractionConstraintHost int_constraint_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class HistSynchronizer {
|
||||||
|
public:
|
||||||
|
virtual void SyncHistograms(QuantileHistMaker::Builder* builder,
|
||||||
|
int starting_index,
|
||||||
|
int sync_count,
|
||||||
|
RegTree *p_tree) = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
class BatchHistSynchronizer: public HistSynchronizer {
|
||||||
|
public:
|
||||||
|
void SyncHistograms(QuantileHistMaker::Builder* builder,
|
||||||
|
int starting_index,
|
||||||
|
int sync_count,
|
||||||
|
RegTree *p_tree) override;
|
||||||
|
};
|
||||||
|
|
||||||
|
class DistributedHistSynchronizer: public HistSynchronizer {
|
||||||
|
public:
|
||||||
|
void SyncHistograms(QuantileHistMaker::Builder* builder_,
|
||||||
|
int starting_index, int sync_count, RegTree *p_tree) override;
|
||||||
|
|
||||||
|
void ParallelSubtractionHist(QuantileHistMaker::Builder* builder,
|
||||||
|
const common::BlockedSpace2d& space,
|
||||||
|
const std::vector<QuantileHistMaker::Builder::ExpandEntry>& nodes,
|
||||||
|
const RegTree * p_tree);
|
||||||
|
};
|
||||||
|
|
||||||
|
class HistRowsAdder {
|
||||||
|
public:
|
||||||
|
virtual void AddHistRows(QuantileHistMaker::Builder* builder,
|
||||||
|
int *starting_index, int *sync_count, RegTree *p_tree) = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
class BatchHistRowsAdder: public HistRowsAdder {
|
||||||
|
public:
|
||||||
|
void AddHistRows(QuantileHistMaker::Builder* builder,
|
||||||
|
int *starting_index, int *sync_count, RegTree *p_tree) override;
|
||||||
|
};
|
||||||
|
|
||||||
|
class DistributedHistRowsAdder: public HistRowsAdder {
|
||||||
|
public:
|
||||||
|
void AddHistRows(QuantileHistMaker::Builder* builder,
|
||||||
|
int *starting_index, int *sync_count, RegTree *p_tree) override;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
} // namespace tree
|
} // namespace tree
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|
||||||
|
|||||||
@ -29,7 +29,8 @@ class QuantileHistMock : public QuantileHistMaker {
|
|||||||
std::unique_ptr<SplitEvaluator> spliteval,
|
std::unique_ptr<SplitEvaluator> spliteval,
|
||||||
FeatureInteractionConstraintHost int_constraint,
|
FeatureInteractionConstraintHost int_constraint,
|
||||||
DMatrix const* fmat)
|
DMatrix const* fmat)
|
||||||
: RealImpl(param, std::move(pruner), std::move(spliteval), std::move(int_constraint), fmat) {}
|
: RealImpl(param, std::move(pruner), std::move(spliteval),
|
||||||
|
std::move(int_constraint), fmat) {}
|
||||||
|
|
||||||
public:
|
public:
|
||||||
void TestInitData(const GHistIndexMatrix& gmat,
|
void TestInitData(const GHistIndexMatrix& gmat,
|
||||||
@ -120,6 +121,147 @@ class QuantileHistMock : public QuantileHistMaker {
|
|||||||
omp_set_num_threads(nthreads);
|
omp_set_num_threads(nthreads);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void TestAddHistRows(const GHistIndexMatrix& gmat,
|
||||||
|
const std::vector<GradientPair>& gpair,
|
||||||
|
DMatrix* p_fmat,
|
||||||
|
RegTree* tree) {
|
||||||
|
RealImpl::InitData(gmat, gpair, *p_fmat, *tree);
|
||||||
|
|
||||||
|
int starting_index = std::numeric_limits<int>::max();
|
||||||
|
int sync_count = 0;
|
||||||
|
nodes_for_explicit_hist_build_.clear();
|
||||||
|
nodes_for_subtraction_trick_.clear();
|
||||||
|
|
||||||
|
tree->ExpandNode(0, 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
|
||||||
|
tree->ExpandNode((*tree)[0].LeftChild(), 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
|
||||||
|
tree->ExpandNode((*tree)[0].RightChild(), 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
|
||||||
|
nodes_for_explicit_hist_build_.emplace_back(3, 4, tree->GetDepth(3), 0.0f, 0);
|
||||||
|
nodes_for_explicit_hist_build_.emplace_back(4, 3, tree->GetDepth(4), 0.0f, 0);
|
||||||
|
nodes_for_subtraction_trick_.emplace_back(5, 6, tree->GetDepth(5), 0.0f, 0);
|
||||||
|
nodes_for_subtraction_trick_.emplace_back(6, 5, tree->GetDepth(6), 0.0f, 0);
|
||||||
|
|
||||||
|
hist_rows_adder_->AddHistRows(this, &starting_index, &sync_count, tree);
|
||||||
|
ASSERT_EQ(sync_count, 2);
|
||||||
|
ASSERT_EQ(starting_index, 3);
|
||||||
|
|
||||||
|
for (const ExpandEntry& node : nodes_for_explicit_hist_build_) {
|
||||||
|
ASSERT_EQ(hist_.RowExists(node.nid), true);
|
||||||
|
}
|
||||||
|
for (const ExpandEntry& node : nodes_for_subtraction_trick_) {
|
||||||
|
ASSERT_EQ(hist_.RowExists(node.nid), true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
void TestSyncHistograms(const GHistIndexMatrix& gmat,
|
||||||
|
const std::vector<GradientPair>& gpair,
|
||||||
|
DMatrix* p_fmat,
|
||||||
|
RegTree* tree) {
|
||||||
|
// init
|
||||||
|
RealImpl::InitData(gmat, gpair, *p_fmat, *tree);
|
||||||
|
|
||||||
|
int starting_index = std::numeric_limits<int>::max();
|
||||||
|
int sync_count = 0;
|
||||||
|
nodes_for_explicit_hist_build_.clear();
|
||||||
|
nodes_for_subtraction_trick_.clear();
|
||||||
|
// level 0
|
||||||
|
nodes_for_explicit_hist_build_.emplace_back(0, -1, tree->GetDepth(0), 0.0f, 0);
|
||||||
|
hist_rows_adder_->AddHistRows(this, &starting_index, &sync_count, tree);
|
||||||
|
tree->ExpandNode(0, 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
|
||||||
|
|
||||||
|
nodes_for_explicit_hist_build_.clear();
|
||||||
|
nodes_for_subtraction_trick_.clear();
|
||||||
|
// level 1
|
||||||
|
nodes_for_explicit_hist_build_.emplace_back((*tree)[0].LeftChild(), (*tree)[0].RightChild(),
|
||||||
|
tree->GetDepth(1), 0.0f, 0);
|
||||||
|
nodes_for_subtraction_trick_.emplace_back((*tree)[0].RightChild(), (*tree)[0].LeftChild(),
|
||||||
|
tree->GetDepth(2), 0.0f, 0);
|
||||||
|
hist_rows_adder_->AddHistRows(this, &starting_index, &sync_count, tree);
|
||||||
|
tree->ExpandNode((*tree)[0].LeftChild(), 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
|
||||||
|
tree->ExpandNode((*tree)[0].RightChild(), 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
|
||||||
|
|
||||||
|
nodes_for_explicit_hist_build_.clear();
|
||||||
|
nodes_for_subtraction_trick_.clear();
|
||||||
|
// level 2
|
||||||
|
nodes_for_explicit_hist_build_.emplace_back(3, 4, tree->GetDepth(3), 0.0f, 0);
|
||||||
|
nodes_for_subtraction_trick_.emplace_back(4, 3, tree->GetDepth(4), 0.0f, 0);
|
||||||
|
nodes_for_explicit_hist_build_.emplace_back(5, 6, tree->GetDepth(5), 0.0f, 0);
|
||||||
|
nodes_for_subtraction_trick_.emplace_back(6, 5, tree->GetDepth(6), 0.0f, 0);
|
||||||
|
hist_rows_adder_->AddHistRows(this, &starting_index, &sync_count, tree);
|
||||||
|
|
||||||
|
const size_t n_nodes = nodes_for_explicit_hist_build_.size();
|
||||||
|
ASSERT_EQ(n_nodes, 2);
|
||||||
|
row_set_collection_.AddSplit(0, (*tree)[0].LeftChild(),
|
||||||
|
(*tree)[0].RightChild(), 4, 4);
|
||||||
|
row_set_collection_.AddSplit(1, (*tree)[1].LeftChild(),
|
||||||
|
(*tree)[1].RightChild(), 2, 2);
|
||||||
|
row_set_collection_.AddSplit(2, (*tree)[2].LeftChild(),
|
||||||
|
(*tree)[2].RightChild(), 2, 2);
|
||||||
|
|
||||||
|
common::BlockedSpace2d space(n_nodes, [&](size_t node) {
|
||||||
|
const int32_t nid = nodes_for_explicit_hist_build_[node].nid;
|
||||||
|
return row_set_collection_[nid].Size();
|
||||||
|
}, 256);
|
||||||
|
|
||||||
|
std::vector<GHistRow> target_hists(n_nodes);
|
||||||
|
for (size_t i = 0; i < nodes_for_explicit_hist_build_.size(); ++i) {
|
||||||
|
const int32_t nid = nodes_for_explicit_hist_build_[i].nid;
|
||||||
|
target_hists[i] = hist_[nid];
|
||||||
|
}
|
||||||
|
|
||||||
|
const size_t nbins = hist_builder_.GetNumBins();
|
||||||
|
// set values to specific nodes hist
|
||||||
|
std::vector<size_t> n_ids = {1, 2};
|
||||||
|
for (size_t i : n_ids) {
|
||||||
|
auto this_hist = hist_[i];
|
||||||
|
using FPType = decltype(tree::GradStats::sum_grad);
|
||||||
|
FPType* p_hist = reinterpret_cast<FPType*>(this_hist.data());
|
||||||
|
for (size_t bin_id = 0; bin_id < 2*nbins; ++bin_id) {
|
||||||
|
p_hist[bin_id] = 2*bin_id;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
n_ids[0] = 3;
|
||||||
|
n_ids[1] = 5;
|
||||||
|
for (size_t i : n_ids) {
|
||||||
|
auto this_hist = hist_[i];
|
||||||
|
using FPType = decltype(tree::GradStats::sum_grad);
|
||||||
|
FPType* p_hist = reinterpret_cast<FPType*>(this_hist.data());
|
||||||
|
for (size_t bin_id = 0; bin_id < 2*nbins; ++bin_id) {
|
||||||
|
p_hist[bin_id] = bin_id;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
hist_buffer_.Reset(1, n_nodes, space, target_hists);
|
||||||
|
// sync hist
|
||||||
|
hist_synchronizer_->SyncHistograms(this, starting_index, sync_count, tree);
|
||||||
|
|
||||||
|
auto check_hist = [] (const GHistRow parent, const GHistRow left,
|
||||||
|
const GHistRow right, size_t begin, size_t end) {
|
||||||
|
using FPType = decltype(tree::GradStats::sum_grad);
|
||||||
|
const FPType* p_parent = reinterpret_cast<const FPType*>(parent.data());
|
||||||
|
const FPType* p_left = reinterpret_cast<const FPType*>(left.data());
|
||||||
|
const FPType* p_right = reinterpret_cast<const FPType*>(right.data());
|
||||||
|
for (size_t i = 2 * begin; i < 2 * end; ++i) {
|
||||||
|
ASSERT_EQ(p_parent[i], p_left[i] + p_right[i]);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
for (const ExpandEntry& node : nodes_for_explicit_hist_build_) {
|
||||||
|
auto this_hist = hist_[node.nid];
|
||||||
|
const size_t parent_id = (*tree)[node.nid].Parent();
|
||||||
|
auto parent_hist = hist_[parent_id];
|
||||||
|
auto sibling_hist = hist_[node.sibling_nid];
|
||||||
|
|
||||||
|
check_hist(parent_hist, this_hist, sibling_hist, 0, nbins);
|
||||||
|
}
|
||||||
|
for (const ExpandEntry& node : nodes_for_subtraction_trick_) {
|
||||||
|
auto this_hist = hist_[node.nid];
|
||||||
|
const size_t parent_id = (*tree)[node.nid].Parent();
|
||||||
|
auto parent_hist = hist_[parent_id];
|
||||||
|
auto sibling_hist = hist_[node.sibling_nid];
|
||||||
|
|
||||||
|
check_hist(parent_hist, this_hist, sibling_hist, 0, nbins);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void TestBuildHist(int nid,
|
void TestBuildHist(int nid,
|
||||||
const GHistIndexMatrix& gmat,
|
const GHistIndexMatrix& gmat,
|
||||||
@ -324,7 +466,7 @@ class QuantileHistMock : public QuantileHistMaker {
|
|||||||
|
|
||||||
public:
|
public:
|
||||||
explicit QuantileHistMock(
|
explicit QuantileHistMock(
|
||||||
const std::vector<std::pair<std::string, std::string> >& args) :
|
const std::vector<std::pair<std::string, std::string> >& args, bool batch = true) :
|
||||||
cfg_{args} {
|
cfg_{args} {
|
||||||
QuantileHistMaker::Configure(args);
|
QuantileHistMaker::Configure(args);
|
||||||
spliteval_->Init(¶m_);
|
spliteval_->Init(¶m_);
|
||||||
@ -336,6 +478,13 @@ class QuantileHistMock : public QuantileHistMaker {
|
|||||||
std::unique_ptr<SplitEvaluator>(spliteval_->GetHostClone()),
|
std::unique_ptr<SplitEvaluator>(spliteval_->GetHostClone()),
|
||||||
int_constraint_,
|
int_constraint_,
|
||||||
dmat_.get()));
|
dmat_.get()));
|
||||||
|
if (batch) {
|
||||||
|
builder_->SetHistSynchronizer(new BatchHistSynchronizer());
|
||||||
|
builder_->SetHistRowsAdder(new BatchHistRowsAdder());
|
||||||
|
} else {
|
||||||
|
builder_->SetHistSynchronizer(new DistributedHistSynchronizer());
|
||||||
|
builder_->SetHistRowsAdder(new DistributedHistRowsAdder());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
~QuantileHistMock() override = default;
|
~QuantileHistMock() override = default;
|
||||||
|
|
||||||
@ -370,6 +519,34 @@ class QuantileHistMock : public QuantileHistMaker {
|
|||||||
|
|
||||||
builder_->TestInitDataSampling(gmat, gpair, dmat_.get(), tree);
|
builder_->TestInitDataSampling(gmat, gpair, dmat_.get(), tree);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void TestAddHistRows() {
|
||||||
|
size_t constexpr kMaxBins = 4;
|
||||||
|
common::GHistIndexMatrix gmat;
|
||||||
|
gmat.Init(dmat_.get(), kMaxBins);
|
||||||
|
|
||||||
|
RegTree tree = RegTree();
|
||||||
|
tree.param.UpdateAllowUnknown(cfg_);
|
||||||
|
std::vector<GradientPair> gpair =
|
||||||
|
{ {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f},
|
||||||
|
{0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f} };
|
||||||
|
builder_->TestAddHistRows(gmat, gpair, dmat_.get(), &tree);
|
||||||
|
}
|
||||||
|
|
||||||
|
void TestSyncHistograms() {
|
||||||
|
size_t constexpr kMaxBins = 4;
|
||||||
|
common::GHistIndexMatrix gmat;
|
||||||
|
gmat.Init(dmat_.get(), kMaxBins);
|
||||||
|
|
||||||
|
RegTree tree = RegTree();
|
||||||
|
tree.param.UpdateAllowUnknown(cfg_);
|
||||||
|
std::vector<GradientPair> gpair =
|
||||||
|
{ {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f},
|
||||||
|
{0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f} };
|
||||||
|
builder_->TestSyncHistograms(gmat, gpair, dmat_.get(), &tree);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
void TestBuildHist() {
|
void TestBuildHist() {
|
||||||
RegTree tree = RegTree();
|
RegTree tree = RegTree();
|
||||||
tree.param.UpdateAllowUnknown(cfg_);
|
tree.param.UpdateAllowUnknown(cfg_);
|
||||||
@ -412,6 +589,34 @@ TEST(QuantileHist, InitDataSampling) {
|
|||||||
maker.TestInitDataSampling();
|
maker.TestInitDataSampling();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(QuantileHist, AddHistRows) {
|
||||||
|
std::vector<std::pair<std::string, std::string>> cfg
|
||||||
|
{{"num_feature", std::to_string(QuantileHistMock::GetNumColumns())}};
|
||||||
|
QuantileHistMock maker(cfg);
|
||||||
|
maker.TestAddHistRows();
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(QuantileHist, SyncHistograms) {
|
||||||
|
std::vector<std::pair<std::string, std::string>> cfg
|
||||||
|
{{"num_feature", std::to_string(QuantileHistMock::GetNumColumns())}};
|
||||||
|
QuantileHistMock maker(cfg);
|
||||||
|
maker.TestSyncHistograms();
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(QuantileHist, DistributedAddHistRows) {
|
||||||
|
std::vector<std::pair<std::string, std::string>> cfg
|
||||||
|
{{"num_feature", std::to_string(QuantileHistMock::GetNumColumns())}};
|
||||||
|
QuantileHistMock maker(cfg, false);
|
||||||
|
maker.TestAddHistRows();
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(QuantileHist, DistributedSyncHistograms) {
|
||||||
|
std::vector<std::pair<std::string, std::string>> cfg
|
||||||
|
{{"num_feature", std::to_string(QuantileHistMock::GetNumColumns())}};
|
||||||
|
QuantileHistMock maker(cfg, false);
|
||||||
|
maker.TestSyncHistograms();
|
||||||
|
}
|
||||||
|
|
||||||
TEST(QuantileHist, BuildHist) {
|
TEST(QuantileHist, BuildHist) {
|
||||||
// Don't enable feature grouping
|
// Don't enable feature grouping
|
||||||
std::vector<std::pair<std::string, std::string>> cfg
|
std::vector<std::pair<std::string, std::string>> cfg
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user