Fuse gpu_hist all-reduce calls where possible (#7867)

This commit is contained in:
Rory Mitchell 2022-05-17 13:27:50 +02:00 committed by GitHub
parent b41cf92dc2
commit 71d3b2e036
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 234 additions and 185 deletions

View File

@ -33,10 +33,11 @@ class Driver {
std::function<bool(ExpandEntryT, ExpandEntryT)>>; std::function<bool(ExpandEntryT, ExpandEntryT)>>;
public: public:
explicit Driver(TrainParam::TreeGrowPolicy policy) explicit Driver(TrainParam param, std::size_t max_node_batch_size = 256)
: policy_(policy), : param_(param),
queue_(policy == TrainParam::kDepthWise ? DepthWise<ExpandEntryT> : max_node_batch_size_(max_node_batch_size),
LossGuide<ExpandEntryT>) {} queue_(param.grow_policy == TrainParam::kDepthWise ? DepthWise<ExpandEntryT>
: LossGuide<ExpandEntryT>) {}
template <typename EntryIterT> template <typename EntryIterT>
void Push(EntryIterT begin, EntryIterT end) { void Push(EntryIterT begin, EntryIterT end) {
for (auto it = begin; it != end; ++it) { for (auto it = begin; it != end; ++it) {
@ -55,24 +56,42 @@ class Driver {
return queue_.empty(); return queue_.empty();
} }
// Can a child of this entry still be expanded?
// can be used to avoid extra work
bool IsChildValid(ExpandEntryT const& parent_entry) {
if (param_.max_depth > 0 && parent_entry.depth + 1 >= param_.max_depth) return false;
if (param_.max_leaves > 0 && num_leaves_ >= param_.max_leaves) return false;
return true;
}
// Return the set of nodes to be expanded // Return the set of nodes to be expanded
// This set has no dependencies between entries so they may be expanded in // This set has no dependencies between entries so they may be expanded in
// parallel or asynchronously // parallel or asynchronously
std::vector<ExpandEntryT> Pop() { std::vector<ExpandEntryT> Pop() {
if (queue_.empty()) return {}; if (queue_.empty()) return {};
// Return a single entry for loss guided mode // Return a single entry for loss guided mode
if (policy_ == TrainParam::kLossGuide) { if (param_.grow_policy == TrainParam::kLossGuide) {
ExpandEntryT e = queue_.top(); ExpandEntryT e = queue_.top();
queue_.pop(); queue_.pop();
return {e};
if (e.IsValid(param_, num_leaves_)) {
num_leaves_++;
return {e};
} else {
return {};
}
} }
// Return nodes on same level for depth wise // Return nodes on same level for depth wise
std::vector<ExpandEntryT> result; std::vector<ExpandEntryT> result;
ExpandEntryT e = queue_.top(); ExpandEntryT e = queue_.top();
int level = e.depth; int level = e.depth;
while (e.depth == level && !queue_.empty()) { while (e.depth == level && !queue_.empty() && result.size() < max_node_batch_size_) {
queue_.pop(); queue_.pop();
result.emplace_back(e); if (e.IsValid(param_, num_leaves_)) {
num_leaves_++;
result.emplace_back(e);
}
if (!queue_.empty()) { if (!queue_.empty()) {
e = queue_.top(); e = queue_.top();
} }
@ -81,7 +100,9 @@ class Driver {
} }
private: private:
TrainParam::TreeGrowPolicy policy_; TrainParam param_;
std::size_t num_leaves_ = 1;
std::size_t max_node_batch_size_;
ExpandQueue queue_; ExpandQueue queue_;
}; };
} // namespace tree } // namespace tree

View File

@ -103,7 +103,7 @@ class GPUHistEvaluator {
} }
/** /**
* \brief Get sorted index storage based on the left node of inputs . * \brief Get sorted index storage based on the left node of inputs.
*/ */
auto SortedIdx(EvaluateSplitInputs<GradientSumT> left) { auto SortedIdx(EvaluateSplitInputs<GradientSumT> left) {
if (left.nidx == RegTree::kRoot && !cat_sorted_idx_.empty()) { if (left.nidx == RegTree::kRoot && !cat_sorted_idx_.empty()) {

View File

@ -247,15 +247,6 @@ void BuildGradientHistogram(EllpackDeviceAccessor const& matrix,
dh::safe_cuda(cudaGetLastError()); dh::safe_cuda(cudaGetLastError());
} }
template void BuildGradientHistogram<GradientPair>(
EllpackDeviceAccessor const& matrix,
FeatureGroupsAccessor const& feature_groups,
common::Span<GradientPair const> gpair,
common::Span<const uint32_t> ridx,
common::Span<GradientPair> histogram,
HistRounding<GradientPair> rounding,
bool force_global_memory);
template void BuildGradientHistogram<GradientPairPrecise>( template void BuildGradientHistogram<GradientPairPrecise>(
EllpackDeviceAccessor const& matrix, EllpackDeviceAccessor const& matrix,
FeatureGroupsAccessor const& feature_groups, FeatureGroupsAccessor const& feature_groups,

View File

@ -179,10 +179,9 @@ class GloablApproxBuilder {
p_last_tree_ = p_tree; p_last_tree_ = p_tree;
this->InitData(p_fmat, hess); this->InitData(p_fmat, hess);
Driver<CPUExpandEntry> driver(static_cast<TrainParam::TreeGrowPolicy>(param_.grow_policy)); Driver<CPUExpandEntry> driver(param_);
auto &tree = *p_tree; auto &tree = *p_tree;
driver.Push({this->InitRoot(p_fmat, gpair, hess, p_tree)}); driver.Push({this->InitRoot(p_fmat, gpair, hess, p_tree)});
bst_node_t num_leaves{1};
auto expand_set = driver.Pop(); auto expand_set = driver.Pop();
/** /**
@ -201,14 +200,9 @@ class GloablApproxBuilder {
// candidates that can be applied. // candidates that can be applied.
std::vector<CPUExpandEntry> applied; std::vector<CPUExpandEntry> applied;
for (auto const &candidate : expand_set) { for (auto const &candidate : expand_set) {
if (!candidate.IsValid(param_, num_leaves)) {
continue;
}
evaluator_.ApplyTreeSplit(candidate, p_tree); evaluator_.ApplyTreeSplit(candidate, p_tree);
applied.push_back(candidate); applied.push_back(candidate);
num_leaves++; if (driver.IsChildValid(candidate)) {
int left_child_nidx = tree[candidate.nid].LeftChild();
if (CPUExpandEntry::ChildIsValid(param_, p_tree->GetDepth(left_child_nidx), num_leaves)) {
valid_candidates.emplace_back(candidate); valid_candidates.emplace_back(candidate);
} }
} }

View File

@ -62,7 +62,7 @@ DMLC_REGISTER_PARAMETER(GPUHistMakerTrainParam);
#endif // !defined(GTEST_TEST) #endif // !defined(GTEST_TEST)
/** /**
* \struct DeviceHistogram * \struct DeviceHistogramStorage
* *
* \summary Data storage for node histograms on device. Automatically expands. * \summary Data storage for node histograms on device. Automatically expands.
* *
@ -72,20 +72,27 @@ DMLC_REGISTER_PARAMETER(GPUHistMakerTrainParam);
* \author Rory * \author Rory
* \date 28/07/2018 * \date 28/07/2018
*/ */
template <typename GradientSumT, size_t kStopGrowingSize = 1 << 26> template <typename GradientSumT, size_t kStopGrowingSize = 1 << 28>
class DeviceHistogram { class DeviceHistogramStorage {
private: private:
/*! \brief Map nidx to starting index of its histogram. */ /*! \brief Map nidx to starting index of its histogram. */
std::map<int, size_t> nidx_map_; std::map<int, size_t> nidx_map_;
// Large buffer of zeroed memory, caches histograms
dh::device_vector<typename GradientSumT::ValueT> data_; dh::device_vector<typename GradientSumT::ValueT> data_;
// If we run out of storage allocate one histogram at a time
// in overflow. Not cached, overwritten when a new histogram
// is requested
dh::device_vector<typename GradientSumT::ValueT> overflow_;
std::map<int, size_t> overflow_nidx_map_;
int n_bins_; int n_bins_;
int device_id_; int device_id_;
static constexpr size_t kNumItemsInGradientSum = static constexpr size_t kNumItemsInGradientSum =
sizeof(GradientSumT) / sizeof(typename GradientSumT::ValueT); sizeof(GradientSumT) / sizeof(typename GradientSumT::ValueT);
static_assert(kNumItemsInGradientSum == 2, static_assert(kNumItemsInGradientSum == 2, "Number of items in gradient type should be 2.");
"Number of items in gradient type should be 2.");
public: public:
// Start with about 16mb
DeviceHistogramStorage() { data_.reserve(1 << 22); }
void Init(int device_id, int n_bins) { void Init(int device_id, int n_bins) {
this->n_bins_ = n_bins; this->n_bins_ = n_bins;
this->device_id_ = device_id; this->device_id_ = device_id;
@ -93,52 +100,47 @@ class DeviceHistogram {
void Reset() { void Reset() {
auto d_data = data_.data().get(); auto d_data = data_.data().get();
dh::LaunchN(data_.size(), dh::LaunchN(data_.size(), [=] __device__(size_t idx) { d_data[idx] = 0.0f; });
[=] __device__(size_t idx) { d_data[idx] = 0.0f; });
nidx_map_.clear(); nidx_map_.clear();
overflow_nidx_map_.clear();
} }
bool HistogramExists(int nidx) const { bool HistogramExists(int nidx) const {
return nidx_map_.find(nidx) != nidx_map_.cend(); return nidx_map_.find(nidx) != nidx_map_.cend() ||
} overflow_nidx_map_.find(nidx) != overflow_nidx_map_.cend();
int Bins() const {
return n_bins_;
}
size_t HistogramSize() const {
return n_bins_ * kNumItemsInGradientSum;
} }
int Bins() const { return n_bins_; }
size_t HistogramSize() const { return n_bins_ * kNumItemsInGradientSum; }
dh::device_vector<typename GradientSumT::ValueT>& Data() { return data_; }
dh::device_vector<typename GradientSumT::ValueT>& Data() { void AllocateHistograms(const std::vector<int>& new_nidxs) {
return data_; for (int nidx : new_nidxs) {
} CHECK(!HistogramExists(nidx));
}
void AllocateHistogram(int nidx) {
if (HistogramExists(nidx)) return;
// Number of items currently used in data // Number of items currently used in data
const size_t used_size = nidx_map_.size() * HistogramSize(); const size_t used_size = nidx_map_.size() * HistogramSize();
const size_t new_used_size = used_size + HistogramSize(); const size_t new_used_size = used_size + HistogramSize() * new_nidxs.size();
if (data_.size() >= kStopGrowingSize) { if (used_size >= kStopGrowingSize) {
// Recycle histogram memory // Use overflow
if (new_used_size <= data_.size()) { // Delete previous entries
// no need to remove old node, just insert the new one. overflow_nidx_map_.clear();
nidx_map_[nidx] = used_size; overflow_.resize(HistogramSize() * new_nidxs.size());
// memset histogram size in bytes // Zero memory
} else { auto d_data = overflow_.data().get();
std::pair<int, size_t> old_entry = *nidx_map_.begin(); dh::LaunchN(overflow_.size(),
nidx_map_.erase(old_entry.first); [=] __device__(size_t idx) { d_data[idx] = 0.0; });
nidx_map_[nidx] = old_entry.second; // Append new histograms
for (int nidx : new_nidxs) {
overflow_nidx_map_[nidx] = overflow_nidx_map_.size() * HistogramSize();
} }
// Zero recycled memory
auto d_data = data_.data().get() + nidx_map_[nidx];
dh::LaunchN(n_bins_ * 2,
[=] __device__(size_t idx) { d_data[idx] = 0.0f; });
} else { } else {
// Append new node histogram CHECK_GE(data_.size(), used_size);
nidx_map_[nidx] = used_size; // Expand if necessary
// Check there is enough memory for another histogram node if (data_.size() < new_used_size) {
if (data_.size() < new_used_size + HistogramSize()) { data_.resize(std::max(data_.size() * 2, new_used_size));
size_t new_required_memory = }
std::max(data_.size() * 2, HistogramSize()); // Append new histograms
data_.resize(new_required_memory); for (int nidx : new_nidxs) {
nidx_map_[nidx] = nidx_map_.size() * HistogramSize();
} }
} }
@ -152,9 +154,16 @@ class DeviceHistogram {
*/ */
common::Span<GradientSumT> GetNodeHistogram(int nidx) { common::Span<GradientSumT> GetNodeHistogram(int nidx) {
CHECK(this->HistogramExists(nidx)); CHECK(this->HistogramExists(nidx));
auto ptr = data_.data().get() + nidx_map_.at(nidx);
return common::Span<GradientSumT>( if (nidx_map_.find(nidx) != nidx_map_.cend()) {
reinterpret_cast<GradientSumT*>(ptr), n_bins_); // Fetch from normal cache
auto ptr = data_.data().get() + nidx_map_.at(nidx);
return common::Span<GradientSumT>(reinterpret_cast<GradientSumT*>(ptr), n_bins_);
} else {
// Fetch from overflow
auto ptr = overflow_.data().get() + overflow_nidx_map_.at(nidx);
return common::Span<GradientSumT>(reinterpret_cast<GradientSumT*>(ptr), n_bins_);
}
} }
}; };
@ -171,7 +180,7 @@ struct GPUHistMakerDevice {
BatchParam batch_param; BatchParam batch_param;
std::unique_ptr<RowPartitioner> row_partitioner; std::unique_ptr<RowPartitioner> row_partitioner;
DeviceHistogram<GradientSumT> hist{}; DeviceHistogramStorage<GradientSumT> hist{};
dh::caching_device_vector<GradientPair> d_gpair; // storage for gpair; dh::caching_device_vector<GradientPair> d_gpair; // storage for gpair;
common::Span<GradientPair> gpair; common::Span<GradientPair> gpair;
@ -195,6 +204,7 @@ struct GPUHistMakerDevice {
std::unique_ptr<FeatureGroups> feature_groups; std::unique_ptr<FeatureGroups> feature_groups;
GPUHistMakerDevice(Context const* ctx, EllpackPageImpl const* _page, GPUHistMakerDevice(Context const* ctx, EllpackPageImpl const* _page,
common::Span<FeatureType const> _feature_types, bst_uint _n_rows, common::Span<FeatureType const> _feature_types, bst_uint _n_rows,
TrainParam _param, uint32_t column_sampler_seed, uint32_t n_features, TrainParam _param, uint32_t column_sampler_seed, uint32_t n_features,
@ -322,7 +332,6 @@ struct GPUHistMakerDevice {
} }
void BuildHist(int nidx) { void BuildHist(int nidx) {
hist.AllocateHistogram(nidx);
auto d_node_hist = hist.GetNodeHistogram(nidx); auto d_node_hist = hist.GetNodeHistogram(nidx);
auto d_ridx = row_partitioner->GetRows(nidx); auto d_ridx = row_partitioner->GetRows(nidx);
BuildGradientHistogram(page->GetDeviceAccessor(ctx_->gpu_id), BuildGradientHistogram(page->GetDeviceAccessor(ctx_->gpu_id),
@ -330,8 +339,12 @@ struct GPUHistMakerDevice {
d_ridx, d_node_hist, histogram_rounding); d_ridx, d_node_hist, histogram_rounding);
} }
void SubtractionTrick(int nidx_parent, int nidx_histogram, // Attempt to do subtraction trick
int nidx_subtraction) { // return true if succeeded
bool SubtractionTrick(int nidx_parent, int nidx_histogram, int nidx_subtraction) {
if (!hist.HistogramExists(nidx_histogram) || !hist.HistogramExists(nidx_parent)) {
return false;
}
auto d_node_hist_parent = hist.GetNodeHistogram(nidx_parent); auto d_node_hist_parent = hist.GetNodeHistogram(nidx_parent);
auto d_node_hist_histogram = hist.GetNodeHistogram(nidx_histogram); auto d_node_hist_histogram = hist.GetNodeHistogram(nidx_histogram);
auto d_node_hist_subtraction = hist.GetNodeHistogram(nidx_subtraction); auto d_node_hist_subtraction = hist.GetNodeHistogram(nidx_subtraction);
@ -340,12 +353,7 @@ struct GPUHistMakerDevice {
d_node_hist_subtraction[idx] = d_node_hist_subtraction[idx] =
d_node_hist_parent[idx] - d_node_hist_histogram[idx]; d_node_hist_parent[idx] - d_node_hist_histogram[idx];
}); });
} return true;
bool CanDoSubtractionTrick(int nidx_parent, int nidx_histogram, int nidx_subtraction) {
// Make sure histograms are already allocated
hist.AllocateHistogram(nidx_subtraction);
return hist.HistogramExists(nidx_histogram) && hist.HistogramExists(nidx_parent);
} }
void UpdatePosition(const GPUExpandEntry &e, RegTree* p_tree) { void UpdatePosition(const GPUExpandEntry &e, RegTree* p_tree) {
@ -505,13 +513,15 @@ struct GPUHistMakerDevice {
row_partitioner.reset(); row_partitioner.reset();
} }
void AllReduceHist(int nidx, dh::AllReducer* reducer) { // num histograms is the number of contiguous histograms in memory to reduce over
void AllReduceHist(int nidx, dh::AllReducer* reducer, int num_histograms) {
monitor.Start("AllReduce"); monitor.Start("AllReduce");
auto d_node_hist = hist.GetNodeHistogram(nidx).data(); auto d_node_hist = hist.GetNodeHistogram(nidx).data();
reducer->AllReduceSum( reducer->AllReduceSum(reinterpret_cast<typename GradientSumT::ValueT*>(d_node_hist),
reinterpret_cast<typename GradientSumT::ValueT*>(d_node_hist), reinterpret_cast<typename GradientSumT::ValueT*>(d_node_hist),
reinterpret_cast<typename GradientSumT::ValueT*>(d_node_hist), page->Cuts().TotalBins() *
page->Cuts().TotalBins() * (sizeof(GradientSumT) / sizeof(typename GradientSumT::ValueT))); (sizeof(GradientSumT) / sizeof(typename GradientSumT::ValueT)) *
num_histograms);
monitor.Stop("AllReduce"); monitor.Stop("AllReduce");
} }
@ -519,33 +529,50 @@ struct GPUHistMakerDevice {
/** /**
* \brief Build GPU local histograms for the left and right child of some parent node * \brief Build GPU local histograms for the left and right child of some parent node
*/ */
void BuildHistLeftRight(const GPUExpandEntry &candidate, int nidx_left, void BuildHistLeftRight(std::vector<GPUExpandEntry> const& candidates, dh::AllReducer* reducer,
int nidx_right, dh::AllReducer* reducer) { const RegTree& tree) {
auto build_hist_nidx = nidx_left; if (candidates.empty()) return;
auto subtraction_trick_nidx = nidx_right; // Some nodes we will manually compute histograms
// others we will do by subtraction
std::vector<int> hist_nidx;
std::vector<int> subtraction_nidx;
for (auto& e : candidates) {
// Decide whether to build the left histogram or right histogram
// Use sum of Hessian as a heuristic to select node with fewest training instances
bool fewer_right = e.split.right_sum.GetHess() < e.split.left_sum.GetHess();
if (fewer_right) {
hist_nidx.emplace_back(tree[e.nid].RightChild());
subtraction_nidx.emplace_back(tree[e.nid].LeftChild());
} else {
hist_nidx.emplace_back(tree[e.nid].LeftChild());
subtraction_nidx.emplace_back(tree[e.nid].RightChild());
}
}
std::vector<int> all_new = hist_nidx;
all_new.insert(all_new.end(), subtraction_nidx.begin(), subtraction_nidx.end());
// Allocate the histograms
// Guaranteed contiguous memory
hist.AllocateHistograms(all_new);
// Decide whether to build the left histogram or right histogram for (auto nidx : hist_nidx) {
// Use sum of Hessian as a heuristic to select node with fewest training instances this->BuildHist(nidx);
bool fewer_right = candidate.split.right_sum.GetHess() < candidate.split.left_sum.GetHess();
if (fewer_right) {
std::swap(build_hist_nidx, subtraction_trick_nidx);
} }
this->BuildHist(build_hist_nidx); // Reduce all in one go
this->AllReduceHist(build_hist_nidx, reducer); // This gives much better latency in a distributed setting
// when processing a large batch
this->AllReduceHist(hist_nidx.at(0), reducer, hist_nidx.size());
// Check whether we can use the subtraction trick to calculate the other for (int i = 0; i < subtraction_nidx.size(); i++) {
bool do_subtraction_trick = this->CanDoSubtractionTrick( auto build_hist_nidx = hist_nidx.at(i);
candidate.nid, build_hist_nidx, subtraction_trick_nidx); auto subtraction_trick_nidx = subtraction_nidx.at(i);
auto parent_nidx = candidates.at(i).nid;
if (do_subtraction_trick) { if (!this->SubtractionTrick(parent_nidx, build_hist_nidx, subtraction_trick_nidx)) {
// Calculate other histogram using subtraction trick // Calculate other histogram manually
this->SubtractionTrick(candidate.nid, build_hist_nidx, this->BuildHist(subtraction_trick_nidx);
subtraction_trick_nidx); this->AllReduceHist(subtraction_trick_nidx, reducer, 1);
} else { }
// Calculate other histogram manually
this->BuildHist(subtraction_trick_nidx);
this->AllReduceHist(subtraction_trick_nidx, reducer);
} }
} }
@ -605,8 +632,9 @@ struct GPUHistMakerDevice {
GradientPairPrecise{}, thrust::plus<GradientPairPrecise>{}); GradientPairPrecise{}, thrust::plus<GradientPairPrecise>{});
rabit::Allreduce<rabit::op::Sum, double>(reinterpret_cast<double*>(&root_sum), 2); rabit::Allreduce<rabit::op::Sum, double>(reinterpret_cast<double*>(&root_sum), 2);
hist.AllocateHistograms({kRootNIdx});
this->BuildHist(kRootNIdx); this->BuildHist(kRootNIdx);
this->AllReduceHist(kRootNIdx, reducer); this->AllReduceHist(kRootNIdx, reducer, 1);
// Remember root stats // Remember root stats
node_sum_gradients[kRootNIdx] = root_sum; node_sum_gradients[kRootNIdx] = root_sum;
@ -624,7 +652,8 @@ struct GPUHistMakerDevice {
RegTree* p_tree, dh::AllReducer* reducer, RegTree* p_tree, dh::AllReducer* reducer,
HostDeviceVector<bst_node_t>* p_out_position) { HostDeviceVector<bst_node_t>* p_out_position) {
auto& tree = *p_tree; auto& tree = *p_tree;
Driver<GPUExpandEntry> driver(static_cast<TrainParam::TreeGrowPolicy>(param.grow_policy)); // Process maximum 32 nodes at a time
Driver<GPUExpandEntry> driver(param, 32);
monitor.Start("Reset"); monitor.Start("Reset");
this->Reset(gpair_all, p_fmat, p_fmat->Info().num_col_); this->Reset(gpair_all, p_fmat, p_fmat->Info().num_col_);
@ -634,48 +663,44 @@ struct GPUHistMakerDevice {
driver.Push({ this->InitRoot(p_tree, reducer) }); driver.Push({ this->InitRoot(p_tree, reducer) });
monitor.Stop("InitRoot"); monitor.Stop("InitRoot");
auto num_leaves = 1;
// The set of leaves that can be expanded asynchronously // The set of leaves that can be expanded asynchronously
auto expand_set = driver.Pop(); auto expand_set = driver.Pop();
while (!expand_set.empty()) { while (!expand_set.empty()) {
auto new_candidates = for (auto& candidate : expand_set) {
pinned.GetSpan<GPUExpandEntry>(expand_set.size() * 2, GPUExpandEntry());
for (auto i = 0ull; i < expand_set.size(); i++) {
auto candidate = expand_set.at(i);
if (!candidate.IsValid(param, num_leaves)) {
continue;
}
this->ApplySplit(candidate, p_tree); this->ApplySplit(candidate, p_tree);
}
// Get the candidates we are allowed to expand further
// e.g. We do not bother further processing nodes whose children are beyond max depth
std::vector<GPUExpandEntry> filtered_expand_set;
std::copy_if(expand_set.begin(), expand_set.end(), std::back_inserter(filtered_expand_set),
[&](const auto& e) { return driver.IsChildValid(e); });
num_leaves++;
auto new_candidates =
pinned.GetSpan<GPUExpandEntry>(filtered_expand_set.size() * 2, GPUExpandEntry());
for (const auto& e : filtered_expand_set) {
monitor.Start("UpdatePosition");
// Update position is only run when child is valid, instead of right after apply
// split (as in approx tree method). Hense we have the finalise position call
// in GPU Hist.
this->UpdatePosition(e, p_tree);
monitor.Stop("UpdatePosition");
}
monitor.Start("BuildHist");
this->BuildHistLeftRight(filtered_expand_set, reducer, tree);
monitor.Stop("BuildHist");
for (auto i = 0ull; i < filtered_expand_set.size(); i++) {
auto candidate = filtered_expand_set.at(i);
int left_child_nidx = tree[candidate.nid].LeftChild(); int left_child_nidx = tree[candidate.nid].LeftChild();
int right_child_nidx = tree[candidate.nid].RightChild(); int right_child_nidx = tree[candidate.nid].RightChild();
// Only create child entries if needed_
if (GPUExpandEntry::ChildIsValid(param, tree.GetDepth(left_child_nidx),
num_leaves)) {
monitor.Start("UpdatePosition");
// Update position is only run when child is valid, instead of right after apply
// split (as in approx tree method). Hense we have the finalise position call
// in GPU Hist.
this->UpdatePosition(candidate, p_tree);
monitor.Stop("UpdatePosition");
monitor.Start("BuildHist"); monitor.Start("EvaluateSplits");
this->BuildHistLeftRight(candidate, left_child_nidx, right_child_nidx, reducer); this->EvaluateLeftRightSplits(candidate, left_child_nidx, right_child_nidx, *p_tree,
monitor.Stop("BuildHist"); new_candidates.subspan(i * 2, 2));
monitor.Stop("EvaluateSplits");
monitor.Start("EvaluateSplits");
this->EvaluateLeftRightSplits(candidate, left_child_nidx, right_child_nidx, *p_tree,
new_candidates.subspan(i * 2, 2));
monitor.Stop("EvaluateSplits");
} else {
// Set default
new_candidates[i * 2] = GPUExpandEntry();
new_candidates[i * 2 + 1] = GPUExpandEntry();
}
} }
dh::DefaultStream().Sync(); dh::DefaultStream().Sync();
driver.Push(new_candidates.begin(), new_candidates.end()); driver.Push(new_candidates.begin(), new_candidates.end());

View File

@ -175,10 +175,9 @@ void QuantileHistMaker::Builder::ExpandTree(DMatrix *p_fmat, RegTree *p_tree,
HostDeviceVector<bst_node_t> *p_out_position) { HostDeviceVector<bst_node_t> *p_out_position) {
monitor_->Start(__func__); monitor_->Start(__func__);
Driver<CPUExpandEntry> driver(static_cast<TrainParam::TreeGrowPolicy>(param_.grow_policy)); Driver<CPUExpandEntry> driver(param_);
driver.Push(this->InitRoot(p_fmat, p_tree, gpair_h)); driver.Push(this->InitRoot(p_fmat, p_tree, gpair_h));
auto const &tree = *p_tree; auto const &tree = *p_tree;
bst_node_t num_leaves{1};
auto expand_set = driver.Pop(); auto expand_set = driver.Pop();
while (!expand_set.empty()) { while (!expand_set.empty()) {
@ -188,13 +187,9 @@ void QuantileHistMaker::Builder::ExpandTree(DMatrix *p_fmat, RegTree *p_tree,
std::vector<CPUExpandEntry> applied; std::vector<CPUExpandEntry> applied;
int32_t depth = expand_set.front().depth + 1; int32_t depth = expand_set.front().depth + 1;
for (auto const& candidate : expand_set) { for (auto const& candidate : expand_set) {
if (!candidate.IsValid(param_, num_leaves)) {
continue;
}
evaluator_->ApplyTreeSplit(candidate, p_tree); evaluator_->ApplyTreeSplit(candidate, p_tree);
applied.push_back(candidate); applied.push_back(candidate);
num_leaves++; if (driver.IsChildValid(candidate)) {
if (CPUExpandEntry::ChildIsValid(param_, depth, num_leaves)) {
valid_candidates.emplace_back(candidate); valid_candidates.emplace_back(candidate);
} }
} }

View File

@ -6,41 +6,58 @@ namespace xgboost {
namespace tree { namespace tree {
TEST(GpuHist, DriverDepthWise) { TEST(GpuHist, DriverDepthWise) {
Driver<GPUExpandEntry> driver(TrainParam::kDepthWise); TrainParam p;
p.InitAllowUnknown(Args{});
p.grow_policy = TrainParam::kDepthWise;
Driver<GPUExpandEntry> driver(p, 2);
EXPECT_TRUE(driver.Pop().empty()); EXPECT_TRUE(driver.Pop().empty());
DeviceSplitCandidate split; DeviceSplitCandidate split;
split.loss_chg = 1.0f; split.loss_chg = 1.0f;
GPUExpandEntry root(0, 0, split, .0f, .0f, .0f); split.left_sum = {0.0f, 1.0f};
split.right_sum = {0.0f, 1.0f};
GPUExpandEntry root(0, 0, split, 2.0f, 1.0f, 1.0f);
driver.Push({root}); driver.Push({root});
EXPECT_EQ(driver.Pop().front().nid, 0); EXPECT_EQ(driver.Pop().front().nid, 0);
driver.Push({GPUExpandEntry{1, 1, split, .0f, .0f, .0f}}); driver.Push({GPUExpandEntry{1, 1, split, 2.0f, 1.0f, 1.0f}});
driver.Push({GPUExpandEntry{2, 1, split, .0f, .0f, .0f}}); driver.Push({GPUExpandEntry{2, 1, split, 2.0f, 1.0f, 1.0f}});
driver.Push({GPUExpandEntry{3, 2, split, .0f, .0f, .0f}}); driver.Push({GPUExpandEntry{3, 1, split, 2.0f, 1.0f, 1.0f}});
// Should return entries from level 1 driver.Push({GPUExpandEntry{4, 2, split, 2.0f, 1.0f, 1.0f}});
// Should return 2 entries from level 1
// as we limited the driver to pop maximum 2 nodes
auto res = driver.Pop(); auto res = driver.Pop();
EXPECT_EQ(res.size(), 2); EXPECT_EQ(res.size(), 2);
for (auto &e : res) { for (auto &e : res) {
EXPECT_EQ(e.depth, 1); EXPECT_EQ(e.depth, 1);
} }
// Should now return 1 entry from level 1
res = driver.Pop(); res = driver.Pop();
EXPECT_EQ(res[0].depth, 2); EXPECT_EQ(res.size(), 1);
EXPECT_EQ(res.at(0).depth, 1);
res = driver.Pop();
EXPECT_EQ(res.at(0).depth, 2);
EXPECT_TRUE(driver.Pop().empty()); EXPECT_TRUE(driver.Pop().empty());
} }
TEST(GpuHist, DriverLossGuided) { TEST(GpuHist, DriverLossGuided) {
DeviceSplitCandidate high_gain; DeviceSplitCandidate high_gain;
high_gain.left_sum = {0.0f, 1.0f};
high_gain.right_sum = {0.0f, 1.0f};
high_gain.loss_chg = 5.0f; high_gain.loss_chg = 5.0f;
DeviceSplitCandidate low_gain; DeviceSplitCandidate low_gain = high_gain;
low_gain.loss_chg = 1.0f; low_gain.loss_chg = 1.0f;
Driver<GPUExpandEntry> driver(TrainParam::kLossGuide); TrainParam p;
p.grow_policy=TrainParam::kLossGuide;
Driver<GPUExpandEntry> driver(p);
EXPECT_TRUE(driver.Pop().empty()); EXPECT_TRUE(driver.Pop().empty());
GPUExpandEntry root(0, 0, high_gain, .0f, .0f, .0f); GPUExpandEntry root(0, 0, high_gain, 2.0f, 1.0f, 1.0f );
driver.Push({root}); driver.Push({root});
EXPECT_EQ(driver.Pop().front().nid, 0); EXPECT_EQ(driver.Pop().front().nid, 0);
// Select high gain first // Select high gain first
driver.Push({GPUExpandEntry{1, 1, low_gain, .0f, .0f, .0f}}); driver.Push({GPUExpandEntry{1, 1, low_gain, 2.0f, 1.0f, 1.0f}});
driver.Push({GPUExpandEntry{2, 2, high_gain, .0f, .0f, .0f}}); driver.Push({GPUExpandEntry{2, 2, high_gain, 2.0f, 1.0f, 1.0f}});
auto res = driver.Pop(); auto res = driver.Pop();
EXPECT_EQ(res.size(), 1); EXPECT_EQ(res.size(), 1);
EXPECT_EQ(res[0].nid, 2); EXPECT_EQ(res[0].nid, 2);
@ -49,8 +66,8 @@ TEST(GpuHist, DriverLossGuided) {
EXPECT_EQ(res[0].nid, 1); EXPECT_EQ(res[0].nid, 1);
// If equal gain, use nid // If equal gain, use nid
driver.Push({GPUExpandEntry{2, 1, low_gain, .0f, .0f, .0f}}); driver.Push({GPUExpandEntry{2, 1, low_gain, 2.0f, 1.0f, 1.0f}});
driver.Push({GPUExpandEntry{1, 1, low_gain, .0f, .0f, .0f}}); driver.Push({GPUExpandEntry{1, 1, low_gain, 2.0f, 1.0f, 1.0f}});
res = driver.Pop(); res = driver.Pop();
EXPECT_EQ(res[0].nid, 1); EXPECT_EQ(res[0].nid, 1);
res = driver.Pop(); res = driver.Pop();

View File

@ -95,7 +95,6 @@ TEST(Histogram, GPUDeterministic) {
std::vector<int> shm_sizes{48 * 1024, 64 * 1024, 160 * 1024}; std::vector<int> shm_sizes{48 * 1024, 64 * 1024, 160 * 1024};
for (bool is_dense : is_dense_array) { for (bool is_dense : is_dense_array) {
for (int shm_size : shm_sizes) { for (int shm_size : shm_sizes) {
TestDeterministicHistogram<GradientPair>(is_dense, shm_size);
TestDeterministicHistogram<GradientPairPrecise>(is_dense, shm_size); TestDeterministicHistogram<GradientPairPrecise>(is_dense, shm_size);
} }
} }

View File

@ -27,31 +27,40 @@ TEST(GpuHist, DeviceHistogram) {
// Ensures that node allocates correctly after reaching `kStopGrowingSize`. // Ensures that node allocates correctly after reaching `kStopGrowingSize`.
dh::safe_cuda(cudaSetDevice(0)); dh::safe_cuda(cudaSetDevice(0));
constexpr size_t kNBins = 128; constexpr size_t kNBins = 128;
constexpr size_t kNNodes = 4; constexpr int kNNodes = 4;
constexpr size_t kStopGrowing = kNNodes * kNBins * 2u; constexpr size_t kStopGrowing = kNNodes * kNBins * 2u;
DeviceHistogram<GradientPairPrecise, kStopGrowing> histogram; DeviceHistogramStorage<GradientPairPrecise, kStopGrowing> histogram;
histogram.Init(0, kNBins); histogram.Init(0, kNBins);
for (size_t i = 0; i < kNNodes; ++i) { for (int i = 0; i < kNNodes; ++i) {
histogram.AllocateHistogram(i); histogram.AllocateHistograms({i});
} }
histogram.Reset(); histogram.Reset();
ASSERT_EQ(histogram.Data().size(), kStopGrowing); ASSERT_EQ(histogram.Data().size(), kStopGrowing);
// Use allocated memory but do not erase nidx_map. // Use allocated memory but do not erase nidx_map.
for (size_t i = 0; i < kNNodes; ++i) { for (int i = 0; i < kNNodes; ++i) {
histogram.AllocateHistogram(i); histogram.AllocateHistograms({i});
} }
for (size_t i = 0; i < kNNodes; ++i) { for (int i = 0; i < kNNodes; ++i) {
ASSERT_TRUE(histogram.HistogramExists(i)); ASSERT_TRUE(histogram.HistogramExists(i));
} }
// Erase existing nidx_map. // Add two new nodes
for (size_t i = kNNodes; i < kNNodes * 2; ++i) { histogram.AllocateHistograms({kNNodes});
histogram.AllocateHistogram(i); histogram.AllocateHistograms({kNNodes + 1});
}
for (size_t i = 0; i < kNNodes; ++i) { // Old cached nodes should still exist
ASSERT_FALSE(histogram.HistogramExists(i)); for (int i = 0; i < kNNodes; ++i) {
ASSERT_TRUE(histogram.HistogramExists(i));
} }
// Should be deleted
ASSERT_FALSE(histogram.HistogramExists(kNNodes));
// Most recent node should exist
ASSERT_TRUE(histogram.HistogramExists(kNNodes + 1));
// Add same node again - should fail
EXPECT_ANY_THROW(histogram.AllocateHistograms({kNNodes + 1}););
} }
std::vector<GradientPairPrecise> GetHostHistGpair() { std::vector<GradientPairPrecise> GetHostHistGpair() {
@ -96,9 +105,9 @@ void TestBuildHist(bool use_shared_memory_histograms) {
thrust::host_vector<common::CompressedByteT> h_gidx_buffer (page->gidx_buffer.HostVector()); thrust::host_vector<common::CompressedByteT> h_gidx_buffer (page->gidx_buffer.HostVector());
maker.row_partitioner.reset(new RowPartitioner(0, kNRows)); maker.row_partitioner.reset(new RowPartitioner(0, kNRows));
maker.hist.AllocateHistogram(0); maker.hist.AllocateHistograms({0});
maker.gpair = gpair.DeviceSpan(); maker.gpair = gpair.DeviceSpan();
maker.histogram_rounding = CreateRoundingFactor<GradientSumT>(maker.gpair);; maker.histogram_rounding = CreateRoundingFactor<GradientSumT>(maker.gpair);
BuildGradientHistogram( BuildGradientHistogram(
page->GetDeviceAccessor(0), maker.feature_groups->DeviceAccessor(0), page->GetDeviceAccessor(0), maker.feature_groups->DeviceAccessor(0),
@ -106,7 +115,7 @@ void TestBuildHist(bool use_shared_memory_histograms) {
maker.hist.GetNodeHistogram(0), maker.histogram_rounding, maker.hist.GetNodeHistogram(0), maker.histogram_rounding,
!use_shared_memory_histograms); !use_shared_memory_histograms);
DeviceHistogram<GradientSumT>& d_hist = maker.hist; DeviceHistogramStorage<GradientSumT>& d_hist = maker.hist;
auto node_histogram = d_hist.GetNodeHistogram(0); auto node_histogram = d_hist.GetNodeHistogram(0);
// d_hist.data stored in float, not gradient pair // d_hist.data stored in float, not gradient pair
@ -129,12 +138,10 @@ void TestBuildHist(bool use_shared_memory_histograms) {
TEST(GpuHist, BuildHistGlobalMem) { TEST(GpuHist, BuildHistGlobalMem) {
TestBuildHist<GradientPairPrecise>(false); TestBuildHist<GradientPairPrecise>(false);
TestBuildHist<GradientPair>(false);
} }
TEST(GpuHist, BuildHistSharedMem) { TEST(GpuHist, BuildHistSharedMem) {
TestBuildHist<GradientPairPrecise>(true); TestBuildHist<GradientPairPrecise>(true);
TestBuildHist<GradientPair>(true);
} }
HistogramCutsWrapper GetHostCutMatrix () { HistogramCutsWrapper GetHostCutMatrix () {
@ -198,7 +205,7 @@ TEST(GpuHist, EvaluateRootSplit) {
// Initialize GPUHistMakerDevice::hist // Initialize GPUHistMakerDevice::hist
maker.hist.Init(0, (max_bins - 1) * kNCols); maker.hist.Init(0, (max_bins - 1) * kNCols);
maker.hist.AllocateHistogram(0); maker.hist.AllocateHistograms({0});
// Each row of hist_gpair represents gpairs for one feature. // Each row of hist_gpair represents gpairs for one feature.
// Each entry represents a bin. // Each entry represents a bin.
std::vector<GradientPairPrecise> hist_gpair = GetHostHistGpair(); std::vector<GradientPairPrecise> hist_gpair = GetHostHistGpair();