Fuse gpu_hist all-reduce calls where possible (#7867)

This commit is contained in:
Rory Mitchell
2022-05-17 13:27:50 +02:00
committed by GitHub
parent b41cf92dc2
commit 71d3b2e036
9 changed files with 234 additions and 185 deletions

View File

@@ -62,7 +62,7 @@ DMLC_REGISTER_PARAMETER(GPUHistMakerTrainParam);
#endif // !defined(GTEST_TEST)
/**
* \struct DeviceHistogram
* \struct DeviceHistogramStorage
*
* \summary Data storage for node histograms on device. Automatically expands.
*
@@ -72,20 +72,27 @@ DMLC_REGISTER_PARAMETER(GPUHistMakerTrainParam);
* \author Rory
* \date 28/07/2018
*/
template <typename GradientSumT, size_t kStopGrowingSize = 1 << 26>
class DeviceHistogram {
template <typename GradientSumT, size_t kStopGrowingSize = 1 << 28>
class DeviceHistogramStorage {
private:
/*! \brief Map nidx to starting index of its histogram. */
std::map<int, size_t> nidx_map_;
// Large buffer of zeroed memory, caches histograms
dh::device_vector<typename GradientSumT::ValueT> data_;
// If we run out of storage allocate one histogram at a time
// in overflow. Not cached, overwritten when a new histogram
// is requested
dh::device_vector<typename GradientSumT::ValueT> overflow_;
std::map<int, size_t> overflow_nidx_map_;
int n_bins_;
int device_id_;
static constexpr size_t kNumItemsInGradientSum =
sizeof(GradientSumT) / sizeof(typename GradientSumT::ValueT);
static_assert(kNumItemsInGradientSum == 2,
"Number of items in gradient type should be 2.");
static_assert(kNumItemsInGradientSum == 2, "Number of items in gradient type should be 2.");
public:
// Start with about 16mb
DeviceHistogramStorage() { data_.reserve(1 << 22); }
void Init(int device_id, int n_bins) {
this->n_bins_ = n_bins;
this->device_id_ = device_id;
@@ -93,52 +100,47 @@ class DeviceHistogram {
void Reset() {
auto d_data = data_.data().get();
dh::LaunchN(data_.size(),
[=] __device__(size_t idx) { d_data[idx] = 0.0f; });
dh::LaunchN(data_.size(), [=] __device__(size_t idx) { d_data[idx] = 0.0f; });
nidx_map_.clear();
overflow_nidx_map_.clear();
}
bool HistogramExists(int nidx) const {
return nidx_map_.find(nidx) != nidx_map_.cend();
}
int Bins() const {
return n_bins_;
}
size_t HistogramSize() const {
return n_bins_ * kNumItemsInGradientSum;
return nidx_map_.find(nidx) != nidx_map_.cend() ||
overflow_nidx_map_.find(nidx) != overflow_nidx_map_.cend();
}
int Bins() const { return n_bins_; }
size_t HistogramSize() const { return n_bins_ * kNumItemsInGradientSum; }
dh::device_vector<typename GradientSumT::ValueT>& Data() { return data_; }
dh::device_vector<typename GradientSumT::ValueT>& Data() {
return data_;
}
void AllocateHistogram(int nidx) {
if (HistogramExists(nidx)) return;
void AllocateHistograms(const std::vector<int>& new_nidxs) {
for (int nidx : new_nidxs) {
CHECK(!HistogramExists(nidx));
}
// Number of items currently used in data
const size_t used_size = nidx_map_.size() * HistogramSize();
const size_t new_used_size = used_size + HistogramSize();
if (data_.size() >= kStopGrowingSize) {
// Recycle histogram memory
if (new_used_size <= data_.size()) {
// no need to remove old node, just insert the new one.
nidx_map_[nidx] = used_size;
// memset histogram size in bytes
} else {
std::pair<int, size_t> old_entry = *nidx_map_.begin();
nidx_map_.erase(old_entry.first);
nidx_map_[nidx] = old_entry.second;
const size_t new_used_size = used_size + HistogramSize() * new_nidxs.size();
if (used_size >= kStopGrowingSize) {
// Use overflow
// Delete previous entries
overflow_nidx_map_.clear();
overflow_.resize(HistogramSize() * new_nidxs.size());
// Zero memory
auto d_data = overflow_.data().get();
dh::LaunchN(overflow_.size(),
[=] __device__(size_t idx) { d_data[idx] = 0.0; });
// Append new histograms
for (int nidx : new_nidxs) {
overflow_nidx_map_[nidx] = overflow_nidx_map_.size() * HistogramSize();
}
// Zero recycled memory
auto d_data = data_.data().get() + nidx_map_[nidx];
dh::LaunchN(n_bins_ * 2,
[=] __device__(size_t idx) { d_data[idx] = 0.0f; });
} else {
// Append new node histogram
nidx_map_[nidx] = used_size;
// Check there is enough memory for another histogram node
if (data_.size() < new_used_size + HistogramSize()) {
size_t new_required_memory =
std::max(data_.size() * 2, HistogramSize());
data_.resize(new_required_memory);
CHECK_GE(data_.size(), used_size);
// Expand if necessary
if (data_.size() < new_used_size) {
data_.resize(std::max(data_.size() * 2, new_used_size));
}
// Append new histograms
for (int nidx : new_nidxs) {
nidx_map_[nidx] = nidx_map_.size() * HistogramSize();
}
}
@@ -152,9 +154,16 @@ class DeviceHistogram {
*/
common::Span<GradientSumT> GetNodeHistogram(int nidx) {
CHECK(this->HistogramExists(nidx));
auto ptr = data_.data().get() + nidx_map_.at(nidx);
return common::Span<GradientSumT>(
reinterpret_cast<GradientSumT*>(ptr), n_bins_);
if (nidx_map_.find(nidx) != nidx_map_.cend()) {
// Fetch from normal cache
auto ptr = data_.data().get() + nidx_map_.at(nidx);
return common::Span<GradientSumT>(reinterpret_cast<GradientSumT*>(ptr), n_bins_);
} else {
// Fetch from overflow
auto ptr = overflow_.data().get() + overflow_nidx_map_.at(nidx);
return common::Span<GradientSumT>(reinterpret_cast<GradientSumT*>(ptr), n_bins_);
}
}
};
@@ -171,7 +180,7 @@ struct GPUHistMakerDevice {
BatchParam batch_param;
std::unique_ptr<RowPartitioner> row_partitioner;
DeviceHistogram<GradientSumT> hist{};
DeviceHistogramStorage<GradientSumT> hist{};
dh::caching_device_vector<GradientPair> d_gpair; // storage for gpair;
common::Span<GradientPair> gpair;
@@ -195,6 +204,7 @@ struct GPUHistMakerDevice {
std::unique_ptr<FeatureGroups> feature_groups;
GPUHistMakerDevice(Context const* ctx, EllpackPageImpl const* _page,
common::Span<FeatureType const> _feature_types, bst_uint _n_rows,
TrainParam _param, uint32_t column_sampler_seed, uint32_t n_features,
@@ -322,7 +332,6 @@ struct GPUHistMakerDevice {
}
void BuildHist(int nidx) {
hist.AllocateHistogram(nidx);
auto d_node_hist = hist.GetNodeHistogram(nidx);
auto d_ridx = row_partitioner->GetRows(nidx);
BuildGradientHistogram(page->GetDeviceAccessor(ctx_->gpu_id),
@@ -330,8 +339,12 @@ struct GPUHistMakerDevice {
d_ridx, d_node_hist, histogram_rounding);
}
void SubtractionTrick(int nidx_parent, int nidx_histogram,
int nidx_subtraction) {
// Attempt to do subtraction trick
// return true if succeeded
bool SubtractionTrick(int nidx_parent, int nidx_histogram, int nidx_subtraction) {
if (!hist.HistogramExists(nidx_histogram) || !hist.HistogramExists(nidx_parent)) {
return false;
}
auto d_node_hist_parent = hist.GetNodeHistogram(nidx_parent);
auto d_node_hist_histogram = hist.GetNodeHistogram(nidx_histogram);
auto d_node_hist_subtraction = hist.GetNodeHistogram(nidx_subtraction);
@@ -340,12 +353,7 @@ struct GPUHistMakerDevice {
d_node_hist_subtraction[idx] =
d_node_hist_parent[idx] - d_node_hist_histogram[idx];
});
}
bool CanDoSubtractionTrick(int nidx_parent, int nidx_histogram, int nidx_subtraction) {
// Make sure histograms are already allocated
hist.AllocateHistogram(nidx_subtraction);
return hist.HistogramExists(nidx_histogram) && hist.HistogramExists(nidx_parent);
return true;
}
void UpdatePosition(const GPUExpandEntry &e, RegTree* p_tree) {
@@ -505,13 +513,15 @@ struct GPUHistMakerDevice {
row_partitioner.reset();
}
void AllReduceHist(int nidx, dh::AllReducer* reducer) {
// num histograms is the number of contiguous histograms in memory to reduce over
void AllReduceHist(int nidx, dh::AllReducer* reducer, int num_histograms) {
monitor.Start("AllReduce");
auto d_node_hist = hist.GetNodeHistogram(nidx).data();
reducer->AllReduceSum(
reinterpret_cast<typename GradientSumT::ValueT*>(d_node_hist),
reinterpret_cast<typename GradientSumT::ValueT*>(d_node_hist),
page->Cuts().TotalBins() * (sizeof(GradientSumT) / sizeof(typename GradientSumT::ValueT)));
reducer->AllReduceSum(reinterpret_cast<typename GradientSumT::ValueT*>(d_node_hist),
reinterpret_cast<typename GradientSumT::ValueT*>(d_node_hist),
page->Cuts().TotalBins() *
(sizeof(GradientSumT) / sizeof(typename GradientSumT::ValueT)) *
num_histograms);
monitor.Stop("AllReduce");
}
@@ -519,33 +529,50 @@ struct GPUHistMakerDevice {
/**
* \brief Build GPU local histograms for the left and right child of some parent node
*/
void BuildHistLeftRight(const GPUExpandEntry &candidate, int nidx_left,
int nidx_right, dh::AllReducer* reducer) {
auto build_hist_nidx = nidx_left;
auto subtraction_trick_nidx = nidx_right;
void BuildHistLeftRight(std::vector<GPUExpandEntry> const& candidates, dh::AllReducer* reducer,
const RegTree& tree) {
if (candidates.empty()) return;
// Some nodes we will manually compute histograms
// others we will do by subtraction
std::vector<int> hist_nidx;
std::vector<int> subtraction_nidx;
for (auto& e : candidates) {
// Decide whether to build the left histogram or right histogram
// Use sum of Hessian as a heuristic to select node with fewest training instances
bool fewer_right = e.split.right_sum.GetHess() < e.split.left_sum.GetHess();
if (fewer_right) {
hist_nidx.emplace_back(tree[e.nid].RightChild());
subtraction_nidx.emplace_back(tree[e.nid].LeftChild());
} else {
hist_nidx.emplace_back(tree[e.nid].LeftChild());
subtraction_nidx.emplace_back(tree[e.nid].RightChild());
}
}
std::vector<int> all_new = hist_nidx;
all_new.insert(all_new.end(), subtraction_nidx.begin(), subtraction_nidx.end());
// Allocate the histograms
// Guaranteed contiguous memory
hist.AllocateHistograms(all_new);
// Decide whether to build the left histogram or right histogram
// Use sum of Hessian as a heuristic to select node with fewest training instances
bool fewer_right = candidate.split.right_sum.GetHess() < candidate.split.left_sum.GetHess();
if (fewer_right) {
std::swap(build_hist_nidx, subtraction_trick_nidx);
for (auto nidx : hist_nidx) {
this->BuildHist(nidx);
}
this->BuildHist(build_hist_nidx);
this->AllReduceHist(build_hist_nidx, reducer);
// Reduce all in one go
// This gives much better latency in a distributed setting
// when processing a large batch
this->AllReduceHist(hist_nidx.at(0), reducer, hist_nidx.size());
// Check whether we can use the subtraction trick to calculate the other
bool do_subtraction_trick = this->CanDoSubtractionTrick(
candidate.nid, build_hist_nidx, subtraction_trick_nidx);
for (int i = 0; i < subtraction_nidx.size(); i++) {
auto build_hist_nidx = hist_nidx.at(i);
auto subtraction_trick_nidx = subtraction_nidx.at(i);
auto parent_nidx = candidates.at(i).nid;
if (do_subtraction_trick) {
// Calculate other histogram using subtraction trick
this->SubtractionTrick(candidate.nid, build_hist_nidx,
subtraction_trick_nidx);
} else {
// Calculate other histogram manually
this->BuildHist(subtraction_trick_nidx);
this->AllReduceHist(subtraction_trick_nidx, reducer);
if (!this->SubtractionTrick(parent_nidx, build_hist_nidx, subtraction_trick_nidx)) {
// Calculate other histogram manually
this->BuildHist(subtraction_trick_nidx);
this->AllReduceHist(subtraction_trick_nidx, reducer, 1);
}
}
}
@@ -605,8 +632,9 @@ struct GPUHistMakerDevice {
GradientPairPrecise{}, thrust::plus<GradientPairPrecise>{});
rabit::Allreduce<rabit::op::Sum, double>(reinterpret_cast<double*>(&root_sum), 2);
hist.AllocateHistograms({kRootNIdx});
this->BuildHist(kRootNIdx);
this->AllReduceHist(kRootNIdx, reducer);
this->AllReduceHist(kRootNIdx, reducer, 1);
// Remember root stats
node_sum_gradients[kRootNIdx] = root_sum;
@@ -624,7 +652,8 @@ struct GPUHistMakerDevice {
RegTree* p_tree, dh::AllReducer* reducer,
HostDeviceVector<bst_node_t>* p_out_position) {
auto& tree = *p_tree;
Driver<GPUExpandEntry> driver(static_cast<TrainParam::TreeGrowPolicy>(param.grow_policy));
// Process maximum 32 nodes at a time
Driver<GPUExpandEntry> driver(param, 32);
monitor.Start("Reset");
this->Reset(gpair_all, p_fmat, p_fmat->Info().num_col_);
@@ -634,48 +663,44 @@ struct GPUHistMakerDevice {
driver.Push({ this->InitRoot(p_tree, reducer) });
monitor.Stop("InitRoot");
auto num_leaves = 1;
// The set of leaves that can be expanded asynchronously
auto expand_set = driver.Pop();
while (!expand_set.empty()) {
auto new_candidates =
pinned.GetSpan<GPUExpandEntry>(expand_set.size() * 2, GPUExpandEntry());
for (auto i = 0ull; i < expand_set.size(); i++) {
auto candidate = expand_set.at(i);
if (!candidate.IsValid(param, num_leaves)) {
continue;
}
for (auto& candidate : expand_set) {
this->ApplySplit(candidate, p_tree);
}
// Get the candidates we are allowed to expand further
// e.g. We do not bother further processing nodes whose children are beyond max depth
std::vector<GPUExpandEntry> filtered_expand_set;
std::copy_if(expand_set.begin(), expand_set.end(), std::back_inserter(filtered_expand_set),
[&](const auto& e) { return driver.IsChildValid(e); });
num_leaves++;
auto new_candidates =
pinned.GetSpan<GPUExpandEntry>(filtered_expand_set.size() * 2, GPUExpandEntry());
for (const auto& e : filtered_expand_set) {
monitor.Start("UpdatePosition");
// Update position is only run when child is valid, instead of right after apply
// split (as in approx tree method). Hense we have the finalise position call
// in GPU Hist.
this->UpdatePosition(e, p_tree);
monitor.Stop("UpdatePosition");
}
monitor.Start("BuildHist");
this->BuildHistLeftRight(filtered_expand_set, reducer, tree);
monitor.Stop("BuildHist");
for (auto i = 0ull; i < filtered_expand_set.size(); i++) {
auto candidate = filtered_expand_set.at(i);
int left_child_nidx = tree[candidate.nid].LeftChild();
int right_child_nidx = tree[candidate.nid].RightChild();
// Only create child entries if needed_
if (GPUExpandEntry::ChildIsValid(param, tree.GetDepth(left_child_nidx),
num_leaves)) {
monitor.Start("UpdatePosition");
// Update position is only run when child is valid, instead of right after apply
// split (as in approx tree method). Hense we have the finalise position call
// in GPU Hist.
this->UpdatePosition(candidate, p_tree);
monitor.Stop("UpdatePosition");
monitor.Start("BuildHist");
this->BuildHistLeftRight(candidate, left_child_nidx, right_child_nidx, reducer);
monitor.Stop("BuildHist");
monitor.Start("EvaluateSplits");
this->EvaluateLeftRightSplits(candidate, left_child_nidx, right_child_nidx, *p_tree,
new_candidates.subspan(i * 2, 2));
monitor.Stop("EvaluateSplits");
} else {
// Set default
new_candidates[i * 2] = GPUExpandEntry();
new_candidates[i * 2 + 1] = GPUExpandEntry();
}
monitor.Start("EvaluateSplits");
this->EvaluateLeftRightSplits(candidate, left_child_nidx, right_child_nidx, *p_tree,
new_candidates.subspan(i * 2, 2));
monitor.Stop("EvaluateSplits");
}
dh::DefaultStream().Sync();
driver.Push(new_candidates.begin(), new_candidates.end());