Fuse gpu_hist all-reduce calls where possible (#7867)
This commit is contained in:
parent
b41cf92dc2
commit
71d3b2e036
@ -33,10 +33,11 @@ class Driver {
|
||||
std::function<bool(ExpandEntryT, ExpandEntryT)>>;
|
||||
|
||||
public:
|
||||
explicit Driver(TrainParam::TreeGrowPolicy policy)
|
||||
: policy_(policy),
|
||||
queue_(policy == TrainParam::kDepthWise ? DepthWise<ExpandEntryT> :
|
||||
LossGuide<ExpandEntryT>) {}
|
||||
explicit Driver(TrainParam param, std::size_t max_node_batch_size = 256)
|
||||
: param_(param),
|
||||
max_node_batch_size_(max_node_batch_size),
|
||||
queue_(param.grow_policy == TrainParam::kDepthWise ? DepthWise<ExpandEntryT>
|
||||
: LossGuide<ExpandEntryT>) {}
|
||||
template <typename EntryIterT>
|
||||
void Push(EntryIterT begin, EntryIterT end) {
|
||||
for (auto it = begin; it != end; ++it) {
|
||||
@ -55,24 +56,42 @@ class Driver {
|
||||
return queue_.empty();
|
||||
}
|
||||
|
||||
// Can a child of this entry still be expanded?
|
||||
// can be used to avoid extra work
|
||||
bool IsChildValid(ExpandEntryT const& parent_entry) {
|
||||
if (param_.max_depth > 0 && parent_entry.depth + 1 >= param_.max_depth) return false;
|
||||
if (param_.max_leaves > 0 && num_leaves_ >= param_.max_leaves) return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
// Return the set of nodes to be expanded
|
||||
// This set has no dependencies between entries so they may be expanded in
|
||||
// parallel or asynchronously
|
||||
std::vector<ExpandEntryT> Pop() {
|
||||
if (queue_.empty()) return {};
|
||||
// Return a single entry for loss guided mode
|
||||
if (policy_ == TrainParam::kLossGuide) {
|
||||
if (param_.grow_policy == TrainParam::kLossGuide) {
|
||||
ExpandEntryT e = queue_.top();
|
||||
queue_.pop();
|
||||
return {e};
|
||||
|
||||
if (e.IsValid(param_, num_leaves_)) {
|
||||
num_leaves_++;
|
||||
return {e};
|
||||
} else {
|
||||
return {};
|
||||
}
|
||||
}
|
||||
// Return nodes on same level for depth wise
|
||||
std::vector<ExpandEntryT> result;
|
||||
ExpandEntryT e = queue_.top();
|
||||
int level = e.depth;
|
||||
while (e.depth == level && !queue_.empty()) {
|
||||
while (e.depth == level && !queue_.empty() && result.size() < max_node_batch_size_) {
|
||||
queue_.pop();
|
||||
result.emplace_back(e);
|
||||
if (e.IsValid(param_, num_leaves_)) {
|
||||
num_leaves_++;
|
||||
result.emplace_back(e);
|
||||
}
|
||||
|
||||
if (!queue_.empty()) {
|
||||
e = queue_.top();
|
||||
}
|
||||
@ -81,7 +100,9 @@ class Driver {
|
||||
}
|
||||
|
||||
private:
|
||||
TrainParam::TreeGrowPolicy policy_;
|
||||
TrainParam param_;
|
||||
std::size_t num_leaves_ = 1;
|
||||
std::size_t max_node_batch_size_;
|
||||
ExpandQueue queue_;
|
||||
};
|
||||
} // namespace tree
|
||||
|
||||
@ -103,7 +103,7 @@ class GPUHistEvaluator {
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Get sorted index storage based on the left node of inputs .
|
||||
* \brief Get sorted index storage based on the left node of inputs.
|
||||
*/
|
||||
auto SortedIdx(EvaluateSplitInputs<GradientSumT> left) {
|
||||
if (left.nidx == RegTree::kRoot && !cat_sorted_idx_.empty()) {
|
||||
|
||||
@ -247,15 +247,6 @@ void BuildGradientHistogram(EllpackDeviceAccessor const& matrix,
|
||||
dh::safe_cuda(cudaGetLastError());
|
||||
}
|
||||
|
||||
template void BuildGradientHistogram<GradientPair>(
|
||||
EllpackDeviceAccessor const& matrix,
|
||||
FeatureGroupsAccessor const& feature_groups,
|
||||
common::Span<GradientPair const> gpair,
|
||||
common::Span<const uint32_t> ridx,
|
||||
common::Span<GradientPair> histogram,
|
||||
HistRounding<GradientPair> rounding,
|
||||
bool force_global_memory);
|
||||
|
||||
template void BuildGradientHistogram<GradientPairPrecise>(
|
||||
EllpackDeviceAccessor const& matrix,
|
||||
FeatureGroupsAccessor const& feature_groups,
|
||||
|
||||
@ -179,10 +179,9 @@ class GloablApproxBuilder {
|
||||
p_last_tree_ = p_tree;
|
||||
this->InitData(p_fmat, hess);
|
||||
|
||||
Driver<CPUExpandEntry> driver(static_cast<TrainParam::TreeGrowPolicy>(param_.grow_policy));
|
||||
Driver<CPUExpandEntry> driver(param_);
|
||||
auto &tree = *p_tree;
|
||||
driver.Push({this->InitRoot(p_fmat, gpair, hess, p_tree)});
|
||||
bst_node_t num_leaves{1};
|
||||
auto expand_set = driver.Pop();
|
||||
|
||||
/**
|
||||
@ -201,14 +200,9 @@ class GloablApproxBuilder {
|
||||
// candidates that can be applied.
|
||||
std::vector<CPUExpandEntry> applied;
|
||||
for (auto const &candidate : expand_set) {
|
||||
if (!candidate.IsValid(param_, num_leaves)) {
|
||||
continue;
|
||||
}
|
||||
evaluator_.ApplyTreeSplit(candidate, p_tree);
|
||||
applied.push_back(candidate);
|
||||
num_leaves++;
|
||||
int left_child_nidx = tree[candidate.nid].LeftChild();
|
||||
if (CPUExpandEntry::ChildIsValid(param_, p_tree->GetDepth(left_child_nidx), num_leaves)) {
|
||||
if (driver.IsChildValid(candidate)) {
|
||||
valid_candidates.emplace_back(candidate);
|
||||
}
|
||||
}
|
||||
|
||||
@ -62,7 +62,7 @@ DMLC_REGISTER_PARAMETER(GPUHistMakerTrainParam);
|
||||
#endif // !defined(GTEST_TEST)
|
||||
|
||||
/**
|
||||
* \struct DeviceHistogram
|
||||
* \struct DeviceHistogramStorage
|
||||
*
|
||||
* \summary Data storage for node histograms on device. Automatically expands.
|
||||
*
|
||||
@ -72,20 +72,27 @@ DMLC_REGISTER_PARAMETER(GPUHistMakerTrainParam);
|
||||
* \author Rory
|
||||
* \date 28/07/2018
|
||||
*/
|
||||
template <typename GradientSumT, size_t kStopGrowingSize = 1 << 26>
|
||||
class DeviceHistogram {
|
||||
template <typename GradientSumT, size_t kStopGrowingSize = 1 << 28>
|
||||
class DeviceHistogramStorage {
|
||||
private:
|
||||
/*! \brief Map nidx to starting index of its histogram. */
|
||||
std::map<int, size_t> nidx_map_;
|
||||
// Large buffer of zeroed memory, caches histograms
|
||||
dh::device_vector<typename GradientSumT::ValueT> data_;
|
||||
// If we run out of storage allocate one histogram at a time
|
||||
// in overflow. Not cached, overwritten when a new histogram
|
||||
// is requested
|
||||
dh::device_vector<typename GradientSumT::ValueT> overflow_;
|
||||
std::map<int, size_t> overflow_nidx_map_;
|
||||
int n_bins_;
|
||||
int device_id_;
|
||||
static constexpr size_t kNumItemsInGradientSum =
|
||||
sizeof(GradientSumT) / sizeof(typename GradientSumT::ValueT);
|
||||
static_assert(kNumItemsInGradientSum == 2,
|
||||
"Number of items in gradient type should be 2.");
|
||||
static_assert(kNumItemsInGradientSum == 2, "Number of items in gradient type should be 2.");
|
||||
|
||||
public:
|
||||
// Start with about 16mb
|
||||
DeviceHistogramStorage() { data_.reserve(1 << 22); }
|
||||
void Init(int device_id, int n_bins) {
|
||||
this->n_bins_ = n_bins;
|
||||
this->device_id_ = device_id;
|
||||
@ -93,52 +100,47 @@ class DeviceHistogram {
|
||||
|
||||
void Reset() {
|
||||
auto d_data = data_.data().get();
|
||||
dh::LaunchN(data_.size(),
|
||||
[=] __device__(size_t idx) { d_data[idx] = 0.0f; });
|
||||
dh::LaunchN(data_.size(), [=] __device__(size_t idx) { d_data[idx] = 0.0f; });
|
||||
nidx_map_.clear();
|
||||
overflow_nidx_map_.clear();
|
||||
}
|
||||
bool HistogramExists(int nidx) const {
|
||||
return nidx_map_.find(nidx) != nidx_map_.cend();
|
||||
}
|
||||
int Bins() const {
|
||||
return n_bins_;
|
||||
}
|
||||
size_t HistogramSize() const {
|
||||
return n_bins_ * kNumItemsInGradientSum;
|
||||
return nidx_map_.find(nidx) != nidx_map_.cend() ||
|
||||
overflow_nidx_map_.find(nidx) != overflow_nidx_map_.cend();
|
||||
}
|
||||
int Bins() const { return n_bins_; }
|
||||
size_t HistogramSize() const { return n_bins_ * kNumItemsInGradientSum; }
|
||||
dh::device_vector<typename GradientSumT::ValueT>& Data() { return data_; }
|
||||
|
||||
dh::device_vector<typename GradientSumT::ValueT>& Data() {
|
||||
return data_;
|
||||
}
|
||||
|
||||
void AllocateHistogram(int nidx) {
|
||||
if (HistogramExists(nidx)) return;
|
||||
void AllocateHistograms(const std::vector<int>& new_nidxs) {
|
||||
for (int nidx : new_nidxs) {
|
||||
CHECK(!HistogramExists(nidx));
|
||||
}
|
||||
// Number of items currently used in data
|
||||
const size_t used_size = nidx_map_.size() * HistogramSize();
|
||||
const size_t new_used_size = used_size + HistogramSize();
|
||||
if (data_.size() >= kStopGrowingSize) {
|
||||
// Recycle histogram memory
|
||||
if (new_used_size <= data_.size()) {
|
||||
// no need to remove old node, just insert the new one.
|
||||
nidx_map_[nidx] = used_size;
|
||||
// memset histogram size in bytes
|
||||
} else {
|
||||
std::pair<int, size_t> old_entry = *nidx_map_.begin();
|
||||
nidx_map_.erase(old_entry.first);
|
||||
nidx_map_[nidx] = old_entry.second;
|
||||
const size_t new_used_size = used_size + HistogramSize() * new_nidxs.size();
|
||||
if (used_size >= kStopGrowingSize) {
|
||||
// Use overflow
|
||||
// Delete previous entries
|
||||
overflow_nidx_map_.clear();
|
||||
overflow_.resize(HistogramSize() * new_nidxs.size());
|
||||
// Zero memory
|
||||
auto d_data = overflow_.data().get();
|
||||
dh::LaunchN(overflow_.size(),
|
||||
[=] __device__(size_t idx) { d_data[idx] = 0.0; });
|
||||
// Append new histograms
|
||||
for (int nidx : new_nidxs) {
|
||||
overflow_nidx_map_[nidx] = overflow_nidx_map_.size() * HistogramSize();
|
||||
}
|
||||
// Zero recycled memory
|
||||
auto d_data = data_.data().get() + nidx_map_[nidx];
|
||||
dh::LaunchN(n_bins_ * 2,
|
||||
[=] __device__(size_t idx) { d_data[idx] = 0.0f; });
|
||||
} else {
|
||||
// Append new node histogram
|
||||
nidx_map_[nidx] = used_size;
|
||||
// Check there is enough memory for another histogram node
|
||||
if (data_.size() < new_used_size + HistogramSize()) {
|
||||
size_t new_required_memory =
|
||||
std::max(data_.size() * 2, HistogramSize());
|
||||
data_.resize(new_required_memory);
|
||||
CHECK_GE(data_.size(), used_size);
|
||||
// Expand if necessary
|
||||
if (data_.size() < new_used_size) {
|
||||
data_.resize(std::max(data_.size() * 2, new_used_size));
|
||||
}
|
||||
// Append new histograms
|
||||
for (int nidx : new_nidxs) {
|
||||
nidx_map_[nidx] = nidx_map_.size() * HistogramSize();
|
||||
}
|
||||
}
|
||||
|
||||
@ -152,9 +154,16 @@ class DeviceHistogram {
|
||||
*/
|
||||
common::Span<GradientSumT> GetNodeHistogram(int nidx) {
|
||||
CHECK(this->HistogramExists(nidx));
|
||||
auto ptr = data_.data().get() + nidx_map_.at(nidx);
|
||||
return common::Span<GradientSumT>(
|
||||
reinterpret_cast<GradientSumT*>(ptr), n_bins_);
|
||||
|
||||
if (nidx_map_.find(nidx) != nidx_map_.cend()) {
|
||||
// Fetch from normal cache
|
||||
auto ptr = data_.data().get() + nidx_map_.at(nidx);
|
||||
return common::Span<GradientSumT>(reinterpret_cast<GradientSumT*>(ptr), n_bins_);
|
||||
} else {
|
||||
// Fetch from overflow
|
||||
auto ptr = overflow_.data().get() + overflow_nidx_map_.at(nidx);
|
||||
return common::Span<GradientSumT>(reinterpret_cast<GradientSumT*>(ptr), n_bins_);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@ -171,7 +180,7 @@ struct GPUHistMakerDevice {
|
||||
BatchParam batch_param;
|
||||
|
||||
std::unique_ptr<RowPartitioner> row_partitioner;
|
||||
DeviceHistogram<GradientSumT> hist{};
|
||||
DeviceHistogramStorage<GradientSumT> hist{};
|
||||
|
||||
dh::caching_device_vector<GradientPair> d_gpair; // storage for gpair;
|
||||
common::Span<GradientPair> gpair;
|
||||
@ -195,6 +204,7 @@ struct GPUHistMakerDevice {
|
||||
|
||||
std::unique_ptr<FeatureGroups> feature_groups;
|
||||
|
||||
|
||||
GPUHistMakerDevice(Context const* ctx, EllpackPageImpl const* _page,
|
||||
common::Span<FeatureType const> _feature_types, bst_uint _n_rows,
|
||||
TrainParam _param, uint32_t column_sampler_seed, uint32_t n_features,
|
||||
@ -322,7 +332,6 @@ struct GPUHistMakerDevice {
|
||||
}
|
||||
|
||||
void BuildHist(int nidx) {
|
||||
hist.AllocateHistogram(nidx);
|
||||
auto d_node_hist = hist.GetNodeHistogram(nidx);
|
||||
auto d_ridx = row_partitioner->GetRows(nidx);
|
||||
BuildGradientHistogram(page->GetDeviceAccessor(ctx_->gpu_id),
|
||||
@ -330,8 +339,12 @@ struct GPUHistMakerDevice {
|
||||
d_ridx, d_node_hist, histogram_rounding);
|
||||
}
|
||||
|
||||
void SubtractionTrick(int nidx_parent, int nidx_histogram,
|
||||
int nidx_subtraction) {
|
||||
// Attempt to do subtraction trick
|
||||
// return true if succeeded
|
||||
bool SubtractionTrick(int nidx_parent, int nidx_histogram, int nidx_subtraction) {
|
||||
if (!hist.HistogramExists(nidx_histogram) || !hist.HistogramExists(nidx_parent)) {
|
||||
return false;
|
||||
}
|
||||
auto d_node_hist_parent = hist.GetNodeHistogram(nidx_parent);
|
||||
auto d_node_hist_histogram = hist.GetNodeHistogram(nidx_histogram);
|
||||
auto d_node_hist_subtraction = hist.GetNodeHistogram(nidx_subtraction);
|
||||
@ -340,12 +353,7 @@ struct GPUHistMakerDevice {
|
||||
d_node_hist_subtraction[idx] =
|
||||
d_node_hist_parent[idx] - d_node_hist_histogram[idx];
|
||||
});
|
||||
}
|
||||
|
||||
bool CanDoSubtractionTrick(int nidx_parent, int nidx_histogram, int nidx_subtraction) {
|
||||
// Make sure histograms are already allocated
|
||||
hist.AllocateHistogram(nidx_subtraction);
|
||||
return hist.HistogramExists(nidx_histogram) && hist.HistogramExists(nidx_parent);
|
||||
return true;
|
||||
}
|
||||
|
||||
void UpdatePosition(const GPUExpandEntry &e, RegTree* p_tree) {
|
||||
@ -505,13 +513,15 @@ struct GPUHistMakerDevice {
|
||||
row_partitioner.reset();
|
||||
}
|
||||
|
||||
void AllReduceHist(int nidx, dh::AllReducer* reducer) {
|
||||
// num histograms is the number of contiguous histograms in memory to reduce over
|
||||
void AllReduceHist(int nidx, dh::AllReducer* reducer, int num_histograms) {
|
||||
monitor.Start("AllReduce");
|
||||
auto d_node_hist = hist.GetNodeHistogram(nidx).data();
|
||||
reducer->AllReduceSum(
|
||||
reinterpret_cast<typename GradientSumT::ValueT*>(d_node_hist),
|
||||
reinterpret_cast<typename GradientSumT::ValueT*>(d_node_hist),
|
||||
page->Cuts().TotalBins() * (sizeof(GradientSumT) / sizeof(typename GradientSumT::ValueT)));
|
||||
reducer->AllReduceSum(reinterpret_cast<typename GradientSumT::ValueT*>(d_node_hist),
|
||||
reinterpret_cast<typename GradientSumT::ValueT*>(d_node_hist),
|
||||
page->Cuts().TotalBins() *
|
||||
(sizeof(GradientSumT) / sizeof(typename GradientSumT::ValueT)) *
|
||||
num_histograms);
|
||||
|
||||
monitor.Stop("AllReduce");
|
||||
}
|
||||
@ -519,33 +529,50 @@ struct GPUHistMakerDevice {
|
||||
/**
|
||||
* \brief Build GPU local histograms for the left and right child of some parent node
|
||||
*/
|
||||
void BuildHistLeftRight(const GPUExpandEntry &candidate, int nidx_left,
|
||||
int nidx_right, dh::AllReducer* reducer) {
|
||||
auto build_hist_nidx = nidx_left;
|
||||
auto subtraction_trick_nidx = nidx_right;
|
||||
void BuildHistLeftRight(std::vector<GPUExpandEntry> const& candidates, dh::AllReducer* reducer,
|
||||
const RegTree& tree) {
|
||||
if (candidates.empty()) return;
|
||||
// Some nodes we will manually compute histograms
|
||||
// others we will do by subtraction
|
||||
std::vector<int> hist_nidx;
|
||||
std::vector<int> subtraction_nidx;
|
||||
for (auto& e : candidates) {
|
||||
// Decide whether to build the left histogram or right histogram
|
||||
// Use sum of Hessian as a heuristic to select node with fewest training instances
|
||||
bool fewer_right = e.split.right_sum.GetHess() < e.split.left_sum.GetHess();
|
||||
if (fewer_right) {
|
||||
hist_nidx.emplace_back(tree[e.nid].RightChild());
|
||||
subtraction_nidx.emplace_back(tree[e.nid].LeftChild());
|
||||
} else {
|
||||
hist_nidx.emplace_back(tree[e.nid].LeftChild());
|
||||
subtraction_nidx.emplace_back(tree[e.nid].RightChild());
|
||||
}
|
||||
}
|
||||
std::vector<int> all_new = hist_nidx;
|
||||
all_new.insert(all_new.end(), subtraction_nidx.begin(), subtraction_nidx.end());
|
||||
// Allocate the histograms
|
||||
// Guaranteed contiguous memory
|
||||
hist.AllocateHistograms(all_new);
|
||||
|
||||
// Decide whether to build the left histogram or right histogram
|
||||
// Use sum of Hessian as a heuristic to select node with fewest training instances
|
||||
bool fewer_right = candidate.split.right_sum.GetHess() < candidate.split.left_sum.GetHess();
|
||||
if (fewer_right) {
|
||||
std::swap(build_hist_nidx, subtraction_trick_nidx);
|
||||
for (auto nidx : hist_nidx) {
|
||||
this->BuildHist(nidx);
|
||||
}
|
||||
|
||||
this->BuildHist(build_hist_nidx);
|
||||
this->AllReduceHist(build_hist_nidx, reducer);
|
||||
// Reduce all in one go
|
||||
// This gives much better latency in a distributed setting
|
||||
// when processing a large batch
|
||||
this->AllReduceHist(hist_nidx.at(0), reducer, hist_nidx.size());
|
||||
|
||||
// Check whether we can use the subtraction trick to calculate the other
|
||||
bool do_subtraction_trick = this->CanDoSubtractionTrick(
|
||||
candidate.nid, build_hist_nidx, subtraction_trick_nidx);
|
||||
for (int i = 0; i < subtraction_nidx.size(); i++) {
|
||||
auto build_hist_nidx = hist_nidx.at(i);
|
||||
auto subtraction_trick_nidx = subtraction_nidx.at(i);
|
||||
auto parent_nidx = candidates.at(i).nid;
|
||||
|
||||
if (do_subtraction_trick) {
|
||||
// Calculate other histogram using subtraction trick
|
||||
this->SubtractionTrick(candidate.nid, build_hist_nidx,
|
||||
subtraction_trick_nidx);
|
||||
} else {
|
||||
// Calculate other histogram manually
|
||||
this->BuildHist(subtraction_trick_nidx);
|
||||
this->AllReduceHist(subtraction_trick_nidx, reducer);
|
||||
if (!this->SubtractionTrick(parent_nidx, build_hist_nidx, subtraction_trick_nidx)) {
|
||||
// Calculate other histogram manually
|
||||
this->BuildHist(subtraction_trick_nidx);
|
||||
this->AllReduceHist(subtraction_trick_nidx, reducer, 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -605,8 +632,9 @@ struct GPUHistMakerDevice {
|
||||
GradientPairPrecise{}, thrust::plus<GradientPairPrecise>{});
|
||||
rabit::Allreduce<rabit::op::Sum, double>(reinterpret_cast<double*>(&root_sum), 2);
|
||||
|
||||
hist.AllocateHistograms({kRootNIdx});
|
||||
this->BuildHist(kRootNIdx);
|
||||
this->AllReduceHist(kRootNIdx, reducer);
|
||||
this->AllReduceHist(kRootNIdx, reducer, 1);
|
||||
|
||||
// Remember root stats
|
||||
node_sum_gradients[kRootNIdx] = root_sum;
|
||||
@ -624,7 +652,8 @@ struct GPUHistMakerDevice {
|
||||
RegTree* p_tree, dh::AllReducer* reducer,
|
||||
HostDeviceVector<bst_node_t>* p_out_position) {
|
||||
auto& tree = *p_tree;
|
||||
Driver<GPUExpandEntry> driver(static_cast<TrainParam::TreeGrowPolicy>(param.grow_policy));
|
||||
// Process maximum 32 nodes at a time
|
||||
Driver<GPUExpandEntry> driver(param, 32);
|
||||
|
||||
monitor.Start("Reset");
|
||||
this->Reset(gpair_all, p_fmat, p_fmat->Info().num_col_);
|
||||
@ -634,48 +663,44 @@ struct GPUHistMakerDevice {
|
||||
driver.Push({ this->InitRoot(p_tree, reducer) });
|
||||
monitor.Stop("InitRoot");
|
||||
|
||||
auto num_leaves = 1;
|
||||
|
||||
// The set of leaves that can be expanded asynchronously
|
||||
auto expand_set = driver.Pop();
|
||||
while (!expand_set.empty()) {
|
||||
auto new_candidates =
|
||||
pinned.GetSpan<GPUExpandEntry>(expand_set.size() * 2, GPUExpandEntry());
|
||||
|
||||
for (auto i = 0ull; i < expand_set.size(); i++) {
|
||||
auto candidate = expand_set.at(i);
|
||||
if (!candidate.IsValid(param, num_leaves)) {
|
||||
continue;
|
||||
}
|
||||
for (auto& candidate : expand_set) {
|
||||
this->ApplySplit(candidate, p_tree);
|
||||
}
|
||||
// Get the candidates we are allowed to expand further
|
||||
// e.g. We do not bother further processing nodes whose children are beyond max depth
|
||||
std::vector<GPUExpandEntry> filtered_expand_set;
|
||||
std::copy_if(expand_set.begin(), expand_set.end(), std::back_inserter(filtered_expand_set),
|
||||
[&](const auto& e) { return driver.IsChildValid(e); });
|
||||
|
||||
num_leaves++;
|
||||
|
||||
auto new_candidates =
|
||||
pinned.GetSpan<GPUExpandEntry>(filtered_expand_set.size() * 2, GPUExpandEntry());
|
||||
|
||||
for (const auto& e : filtered_expand_set) {
|
||||
monitor.Start("UpdatePosition");
|
||||
// Update position is only run when child is valid, instead of right after apply
|
||||
// split (as in approx tree method). Hense we have the finalise position call
|
||||
// in GPU Hist.
|
||||
this->UpdatePosition(e, p_tree);
|
||||
monitor.Stop("UpdatePosition");
|
||||
}
|
||||
|
||||
monitor.Start("BuildHist");
|
||||
this->BuildHistLeftRight(filtered_expand_set, reducer, tree);
|
||||
monitor.Stop("BuildHist");
|
||||
|
||||
for (auto i = 0ull; i < filtered_expand_set.size(); i++) {
|
||||
auto candidate = filtered_expand_set.at(i);
|
||||
int left_child_nidx = tree[candidate.nid].LeftChild();
|
||||
int right_child_nidx = tree[candidate.nid].RightChild();
|
||||
// Only create child entries if needed_
|
||||
if (GPUExpandEntry::ChildIsValid(param, tree.GetDepth(left_child_nidx),
|
||||
num_leaves)) {
|
||||
monitor.Start("UpdatePosition");
|
||||
// Update position is only run when child is valid, instead of right after apply
|
||||
// split (as in approx tree method). Hense we have the finalise position call
|
||||
// in GPU Hist.
|
||||
this->UpdatePosition(candidate, p_tree);
|
||||
monitor.Stop("UpdatePosition");
|
||||
|
||||
monitor.Start("BuildHist");
|
||||
this->BuildHistLeftRight(candidate, left_child_nidx, right_child_nidx, reducer);
|
||||
monitor.Stop("BuildHist");
|
||||
|
||||
monitor.Start("EvaluateSplits");
|
||||
this->EvaluateLeftRightSplits(candidate, left_child_nidx, right_child_nidx, *p_tree,
|
||||
new_candidates.subspan(i * 2, 2));
|
||||
monitor.Stop("EvaluateSplits");
|
||||
} else {
|
||||
// Set default
|
||||
new_candidates[i * 2] = GPUExpandEntry();
|
||||
new_candidates[i * 2 + 1] = GPUExpandEntry();
|
||||
}
|
||||
monitor.Start("EvaluateSplits");
|
||||
this->EvaluateLeftRightSplits(candidate, left_child_nidx, right_child_nidx, *p_tree,
|
||||
new_candidates.subspan(i * 2, 2));
|
||||
monitor.Stop("EvaluateSplits");
|
||||
}
|
||||
dh::DefaultStream().Sync();
|
||||
driver.Push(new_candidates.begin(), new_candidates.end());
|
||||
|
||||
@ -175,10 +175,9 @@ void QuantileHistMaker::Builder::ExpandTree(DMatrix *p_fmat, RegTree *p_tree,
|
||||
HostDeviceVector<bst_node_t> *p_out_position) {
|
||||
monitor_->Start(__func__);
|
||||
|
||||
Driver<CPUExpandEntry> driver(static_cast<TrainParam::TreeGrowPolicy>(param_.grow_policy));
|
||||
Driver<CPUExpandEntry> driver(param_);
|
||||
driver.Push(this->InitRoot(p_fmat, p_tree, gpair_h));
|
||||
auto const &tree = *p_tree;
|
||||
bst_node_t num_leaves{1};
|
||||
auto expand_set = driver.Pop();
|
||||
|
||||
while (!expand_set.empty()) {
|
||||
@ -188,13 +187,9 @@ void QuantileHistMaker::Builder::ExpandTree(DMatrix *p_fmat, RegTree *p_tree,
|
||||
std::vector<CPUExpandEntry> applied;
|
||||
int32_t depth = expand_set.front().depth + 1;
|
||||
for (auto const& candidate : expand_set) {
|
||||
if (!candidate.IsValid(param_, num_leaves)) {
|
||||
continue;
|
||||
}
|
||||
evaluator_->ApplyTreeSplit(candidate, p_tree);
|
||||
applied.push_back(candidate);
|
||||
num_leaves++;
|
||||
if (CPUExpandEntry::ChildIsValid(param_, depth, num_leaves)) {
|
||||
if (driver.IsChildValid(candidate)) {
|
||||
valid_candidates.emplace_back(candidate);
|
||||
}
|
||||
}
|
||||
|
||||
@ -6,41 +6,58 @@ namespace xgboost {
|
||||
namespace tree {
|
||||
|
||||
TEST(GpuHist, DriverDepthWise) {
|
||||
Driver<GPUExpandEntry> driver(TrainParam::kDepthWise);
|
||||
TrainParam p;
|
||||
p.InitAllowUnknown(Args{});
|
||||
p.grow_policy = TrainParam::kDepthWise;
|
||||
Driver<GPUExpandEntry> driver(p, 2);
|
||||
EXPECT_TRUE(driver.Pop().empty());
|
||||
DeviceSplitCandidate split;
|
||||
split.loss_chg = 1.0f;
|
||||
GPUExpandEntry root(0, 0, split, .0f, .0f, .0f);
|
||||
split.left_sum = {0.0f, 1.0f};
|
||||
split.right_sum = {0.0f, 1.0f};
|
||||
GPUExpandEntry root(0, 0, split, 2.0f, 1.0f, 1.0f);
|
||||
driver.Push({root});
|
||||
EXPECT_EQ(driver.Pop().front().nid, 0);
|
||||
driver.Push({GPUExpandEntry{1, 1, split, .0f, .0f, .0f}});
|
||||
driver.Push({GPUExpandEntry{2, 1, split, .0f, .0f, .0f}});
|
||||
driver.Push({GPUExpandEntry{3, 2, split, .0f, .0f, .0f}});
|
||||
// Should return entries from level 1
|
||||
driver.Push({GPUExpandEntry{1, 1, split, 2.0f, 1.0f, 1.0f}});
|
||||
driver.Push({GPUExpandEntry{2, 1, split, 2.0f, 1.0f, 1.0f}});
|
||||
driver.Push({GPUExpandEntry{3, 1, split, 2.0f, 1.0f, 1.0f}});
|
||||
driver.Push({GPUExpandEntry{4, 2, split, 2.0f, 1.0f, 1.0f}});
|
||||
// Should return 2 entries from level 1
|
||||
// as we limited the driver to pop maximum 2 nodes
|
||||
auto res = driver.Pop();
|
||||
EXPECT_EQ(res.size(), 2);
|
||||
for (auto &e : res) {
|
||||
EXPECT_EQ(e.depth, 1);
|
||||
}
|
||||
|
||||
// Should now return 1 entry from level 1
|
||||
res = driver.Pop();
|
||||
EXPECT_EQ(res[0].depth, 2);
|
||||
EXPECT_EQ(res.size(), 1);
|
||||
EXPECT_EQ(res.at(0).depth, 1);
|
||||
|
||||
res = driver.Pop();
|
||||
EXPECT_EQ(res.at(0).depth, 2);
|
||||
EXPECT_TRUE(driver.Pop().empty());
|
||||
}
|
||||
|
||||
TEST(GpuHist, DriverLossGuided) {
|
||||
DeviceSplitCandidate high_gain;
|
||||
high_gain.left_sum = {0.0f, 1.0f};
|
||||
high_gain.right_sum = {0.0f, 1.0f};
|
||||
high_gain.loss_chg = 5.0f;
|
||||
DeviceSplitCandidate low_gain;
|
||||
DeviceSplitCandidate low_gain = high_gain;
|
||||
low_gain.loss_chg = 1.0f;
|
||||
|
||||
Driver<GPUExpandEntry> driver(TrainParam::kLossGuide);
|
||||
TrainParam p;
|
||||
p.grow_policy=TrainParam::kLossGuide;
|
||||
Driver<GPUExpandEntry> driver(p);
|
||||
EXPECT_TRUE(driver.Pop().empty());
|
||||
GPUExpandEntry root(0, 0, high_gain, .0f, .0f, .0f);
|
||||
GPUExpandEntry root(0, 0, high_gain, 2.0f, 1.0f, 1.0f );
|
||||
driver.Push({root});
|
||||
EXPECT_EQ(driver.Pop().front().nid, 0);
|
||||
// Select high gain first
|
||||
driver.Push({GPUExpandEntry{1, 1, low_gain, .0f, .0f, .0f}});
|
||||
driver.Push({GPUExpandEntry{2, 2, high_gain, .0f, .0f, .0f}});
|
||||
driver.Push({GPUExpandEntry{1, 1, low_gain, 2.0f, 1.0f, 1.0f}});
|
||||
driver.Push({GPUExpandEntry{2, 2, high_gain, 2.0f, 1.0f, 1.0f}});
|
||||
auto res = driver.Pop();
|
||||
EXPECT_EQ(res.size(), 1);
|
||||
EXPECT_EQ(res[0].nid, 2);
|
||||
@ -49,8 +66,8 @@ TEST(GpuHist, DriverLossGuided) {
|
||||
EXPECT_EQ(res[0].nid, 1);
|
||||
|
||||
// If equal gain, use nid
|
||||
driver.Push({GPUExpandEntry{2, 1, low_gain, .0f, .0f, .0f}});
|
||||
driver.Push({GPUExpandEntry{1, 1, low_gain, .0f, .0f, .0f}});
|
||||
driver.Push({GPUExpandEntry{2, 1, low_gain, 2.0f, 1.0f, 1.0f}});
|
||||
driver.Push({GPUExpandEntry{1, 1, low_gain, 2.0f, 1.0f, 1.0f}});
|
||||
res = driver.Pop();
|
||||
EXPECT_EQ(res[0].nid, 1);
|
||||
res = driver.Pop();
|
||||
|
||||
@ -95,7 +95,6 @@ TEST(Histogram, GPUDeterministic) {
|
||||
std::vector<int> shm_sizes{48 * 1024, 64 * 1024, 160 * 1024};
|
||||
for (bool is_dense : is_dense_array) {
|
||||
for (int shm_size : shm_sizes) {
|
||||
TestDeterministicHistogram<GradientPair>(is_dense, shm_size);
|
||||
TestDeterministicHistogram<GradientPairPrecise>(is_dense, shm_size);
|
||||
}
|
||||
}
|
||||
|
||||
@ -27,31 +27,40 @@ TEST(GpuHist, DeviceHistogram) {
|
||||
// Ensures that node allocates correctly after reaching `kStopGrowingSize`.
|
||||
dh::safe_cuda(cudaSetDevice(0));
|
||||
constexpr size_t kNBins = 128;
|
||||
constexpr size_t kNNodes = 4;
|
||||
constexpr int kNNodes = 4;
|
||||
constexpr size_t kStopGrowing = kNNodes * kNBins * 2u;
|
||||
DeviceHistogram<GradientPairPrecise, kStopGrowing> histogram;
|
||||
DeviceHistogramStorage<GradientPairPrecise, kStopGrowing> histogram;
|
||||
histogram.Init(0, kNBins);
|
||||
for (size_t i = 0; i < kNNodes; ++i) {
|
||||
histogram.AllocateHistogram(i);
|
||||
for (int i = 0; i < kNNodes; ++i) {
|
||||
histogram.AllocateHistograms({i});
|
||||
}
|
||||
histogram.Reset();
|
||||
ASSERT_EQ(histogram.Data().size(), kStopGrowing);
|
||||
|
||||
// Use allocated memory but do not erase nidx_map.
|
||||
for (size_t i = 0; i < kNNodes; ++i) {
|
||||
histogram.AllocateHistogram(i);
|
||||
for (int i = 0; i < kNNodes; ++i) {
|
||||
histogram.AllocateHistograms({i});
|
||||
}
|
||||
for (size_t i = 0; i < kNNodes; ++i) {
|
||||
for (int i = 0; i < kNNodes; ++i) {
|
||||
ASSERT_TRUE(histogram.HistogramExists(i));
|
||||
}
|
||||
|
||||
// Erase existing nidx_map.
|
||||
for (size_t i = kNNodes; i < kNNodes * 2; ++i) {
|
||||
histogram.AllocateHistogram(i);
|
||||
}
|
||||
for (size_t i = 0; i < kNNodes; ++i) {
|
||||
ASSERT_FALSE(histogram.HistogramExists(i));
|
||||
// Add two new nodes
|
||||
histogram.AllocateHistograms({kNNodes});
|
||||
histogram.AllocateHistograms({kNNodes + 1});
|
||||
|
||||
// Old cached nodes should still exist
|
||||
for (int i = 0; i < kNNodes; ++i) {
|
||||
ASSERT_TRUE(histogram.HistogramExists(i));
|
||||
}
|
||||
|
||||
// Should be deleted
|
||||
ASSERT_FALSE(histogram.HistogramExists(kNNodes));
|
||||
// Most recent node should exist
|
||||
ASSERT_TRUE(histogram.HistogramExists(kNNodes + 1));
|
||||
|
||||
// Add same node again - should fail
|
||||
EXPECT_ANY_THROW(histogram.AllocateHistograms({kNNodes + 1}););
|
||||
}
|
||||
|
||||
std::vector<GradientPairPrecise> GetHostHistGpair() {
|
||||
@ -96,9 +105,9 @@ void TestBuildHist(bool use_shared_memory_histograms) {
|
||||
|
||||
thrust::host_vector<common::CompressedByteT> h_gidx_buffer (page->gidx_buffer.HostVector());
|
||||
maker.row_partitioner.reset(new RowPartitioner(0, kNRows));
|
||||
maker.hist.AllocateHistogram(0);
|
||||
maker.hist.AllocateHistograms({0});
|
||||
maker.gpair = gpair.DeviceSpan();
|
||||
maker.histogram_rounding = CreateRoundingFactor<GradientSumT>(maker.gpair);;
|
||||
maker.histogram_rounding = CreateRoundingFactor<GradientSumT>(maker.gpair);
|
||||
|
||||
BuildGradientHistogram(
|
||||
page->GetDeviceAccessor(0), maker.feature_groups->DeviceAccessor(0),
|
||||
@ -106,7 +115,7 @@ void TestBuildHist(bool use_shared_memory_histograms) {
|
||||
maker.hist.GetNodeHistogram(0), maker.histogram_rounding,
|
||||
!use_shared_memory_histograms);
|
||||
|
||||
DeviceHistogram<GradientSumT>& d_hist = maker.hist;
|
||||
DeviceHistogramStorage<GradientSumT>& d_hist = maker.hist;
|
||||
|
||||
auto node_histogram = d_hist.GetNodeHistogram(0);
|
||||
// d_hist.data stored in float, not gradient pair
|
||||
@ -129,12 +138,10 @@ void TestBuildHist(bool use_shared_memory_histograms) {
|
||||
|
||||
TEST(GpuHist, BuildHistGlobalMem) {
|
||||
TestBuildHist<GradientPairPrecise>(false);
|
||||
TestBuildHist<GradientPair>(false);
|
||||
}
|
||||
|
||||
TEST(GpuHist, BuildHistSharedMem) {
|
||||
TestBuildHist<GradientPairPrecise>(true);
|
||||
TestBuildHist<GradientPair>(true);
|
||||
}
|
||||
|
||||
HistogramCutsWrapper GetHostCutMatrix () {
|
||||
@ -198,7 +205,7 @@ TEST(GpuHist, EvaluateRootSplit) {
|
||||
|
||||
// Initialize GPUHistMakerDevice::hist
|
||||
maker.hist.Init(0, (max_bins - 1) * kNCols);
|
||||
maker.hist.AllocateHistogram(0);
|
||||
maker.hist.AllocateHistograms({0});
|
||||
// Each row of hist_gpair represents gpairs for one feature.
|
||||
// Each entry represents a bin.
|
||||
std::vector<GradientPairPrecise> hist_gpair = GetHostHistGpair();
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user