Reduce device synchronisation (#5631)

* Reduce device synchronisation

* Initialise pinned memory
This commit is contained in:
Rory Mitchell
2020-05-07 21:19:46 +12:00
committed by GitHub
parent 9910265064
commit fcf57823b6
7 changed files with 260 additions and 118 deletions

View File

@@ -30,6 +30,7 @@
#include "gpu_hist/row_partitioner.cuh"
#include "gpu_hist/histogram.cuh"
#include "gpu_hist/evaluate_splits.cuh"
#include "gpu_hist/driver.cuh"
namespace xgboost {
namespace tree {
@@ -57,58 +58,6 @@ struct GPUHistMakerTrainParam
DMLC_REGISTER_PARAMETER(GPUHistMakerTrainParam);
#endif // !defined(GTEST_TEST)
struct ExpandEntry {
int nid;
int depth;
DeviceSplitCandidate split;
uint64_t timestamp;
ExpandEntry() = default;
ExpandEntry(int nid, int depth, DeviceSplitCandidate split,
uint64_t timestamp)
: nid(nid), depth(depth), split(std::move(split)), timestamp(timestamp) {}
bool IsValid(const TrainParam& param, int num_leaves) const {
if (split.loss_chg <= kRtEps) return false;
if (split.left_sum.GetHess() == 0 || split.right_sum.GetHess() == 0) {
return false;
}
if (split.loss_chg < param.min_split_loss) { return false; }
if (param.max_depth > 0 && depth == param.max_depth) {return false; }
if (param.max_leaves > 0 && num_leaves == param.max_leaves) { return false; }
return true;
}
static bool ChildIsValid(const TrainParam& param, int depth, int num_leaves) {
if (param.max_depth > 0 && depth >= param.max_depth) return false;
if (param.max_leaves > 0 && num_leaves >= param.max_leaves) return false;
return true;
}
friend std::ostream& operator<<(std::ostream& os, const ExpandEntry& e) {
os << "ExpandEntry: \n";
os << "nidx: " << e.nid << "\n";
os << "depth: " << e.depth << "\n";
os << "loss: " << e.split.loss_chg << "\n";
os << "left_sum: " << e.split.left_sum << "\n";
os << "right_sum: " << e.split.right_sum << "\n";
return os;
}
};
inline static bool DepthWise(const ExpandEntry& lhs, const ExpandEntry& rhs) {
if (lhs.depth == rhs.depth) {
return lhs.timestamp > rhs.timestamp; // favor small timestamp
} else {
return lhs.depth > rhs.depth; // favor small depth
}
}
inline static bool LossGuide(const ExpandEntry& lhs, const ExpandEntry& rhs) {
if (lhs.split.loss_chg == rhs.split.loss_chg) {
return lhs.timestamp > rhs.timestamp; // favor small timestamp
} else {
return lhs.split.loss_chg < rhs.split.loss_chg; // favor large loss_chg
}
}
/**
* \struct DeviceHistogram
*
@@ -243,6 +192,8 @@ struct GPUHistMakerDevice {
GradientSumT histogram_rounding;
dh::PinnedMemory pinned;
std::vector<cudaStream_t> streams{};
common::Monitor monitor;
@@ -250,11 +201,6 @@ struct GPUHistMakerDevice {
common::ColumnSampler column_sampler;
FeatureInteractionConstraintDevice interaction_constraints;
using ExpandQueue =
std::priority_queue<ExpandEntry, std::vector<ExpandEntry>,
std::function<bool(ExpandEntry, ExpandEntry)>>;
std::unique_ptr<ExpandQueue> qexpand;
std::unique_ptr<GradientBasedSampler> sampler;
GPUHistMakerDevice(int _device_id,
@@ -314,11 +260,6 @@ struct GPUHistMakerDevice {
// Note that the column sampler must be passed by value because it is not
// thread safe
void Reset(HostDeviceVector<GradientPair>* dh_gpair, DMatrix* dmat, int64_t num_columns) {
if (param.grow_policy == TrainParam::kLossGuide) {
qexpand.reset(new ExpandQueue(LossGuide));
} else {
qexpand.reset(new ExpandQueue(DepthWise));
}
this->column_sampler.Init(num_columns, param.colsample_bynode,
param.colsample_bylevel, param.colsample_bytree);
dh::safe_cuda(cudaSetDevice(device_id));
@@ -370,9 +311,9 @@ struct GPUHistMakerDevice {
return result.front();
}
std::vector<DeviceSplitCandidate> EvaluateLeftRightSplits(
ExpandEntry candidate, int left_nidx, int right_nidx,
const RegTree& tree) {
void EvaluateLeftRightSplits(
ExpandEntry candidate, int left_nidx, int right_nidx, const RegTree& tree,
common::Span<ExpandEntry> pinned_candidates_out) {
dh::TemporaryArray<DeviceSplitCandidate> splits_out(2);
GPUTrainingParam gpu_param(param);
auto left_sampled_features =
@@ -412,12 +353,19 @@ struct GPUHistMakerDevice {
hist.GetNodeHistogram(right_nidx),
node_value_constraints[right_nidx],
dh::ToSpan(monotone_constraints)};
EvaluateSplits(dh::ToSpan(splits_out), left, right);
std::vector<DeviceSplitCandidate> result(2);
dh::safe_cuda(cudaMemcpy(result.data(), splits_out.data().get(),
sizeof(DeviceSplitCandidate) * splits_out.size(),
cudaMemcpyDeviceToHost));
return result;
auto d_splits_out = dh::ToSpan(splits_out);
EvaluateSplits(d_splits_out, left, right);
dh::TemporaryArray<ExpandEntry> entries(2);
auto d_entries = entries.data().get();
dh::LaunchN(device_id, 1, [=] __device__(size_t idx) {
d_entries[0] =
ExpandEntry(left_nidx, candidate.depth + 1, d_splits_out[0]);
d_entries[1] =
ExpandEntry(right_nidx, candidate.depth + 1, d_splits_out[1]);
});
dh::safe_cuda(cudaMemcpyAsync(
pinned_candidates_out.data(), entries.data().get(),
sizeof(ExpandEntry) * entries.size(), cudaMemcpyDeviceToHost));
}
void BuildHist(int nidx) {
@@ -637,7 +585,7 @@ struct GPUHistMakerDevice {
tree[candidate.nid].RightChild());
}
void InitRoot(RegTree* p_tree, dh::AllReducer* reducer) {
ExpandEntry InitRoot(RegTree* p_tree, dh::AllReducer* reducer) {
constexpr bst_node_t kRootNIdx = 0;
dh::XGBCachingDeviceAllocator<char> alloc;
GradientPair root_sum = thrust::reduce(
@@ -662,61 +610,66 @@ struct GPUHistMakerDevice {
// Generate first split
auto split = this->EvaluateRootSplit(root_sum);
qexpand->push(
ExpandEntry(kRootNIdx, p_tree->GetDepth(kRootNIdx), split, 0));
return ExpandEntry(kRootNIdx, p_tree->GetDepth(kRootNIdx), split);
}
void UpdateTree(HostDeviceVector<GradientPair>* gpair_all, DMatrix* p_fmat,
RegTree* p_tree, dh::AllReducer* reducer) {
auto& tree = *p_tree;
Driver driver(static_cast<TrainParam::TreeGrowPolicy>(param.grow_policy));
monitor.Start("Reset");
this->Reset(gpair_all, p_fmat, p_fmat->Info().num_col_);
monitor.Stop("Reset");
monitor.Start("InitRoot");
this->InitRoot(p_tree, reducer);
driver.Push({ this->InitRoot(p_tree, reducer) });
monitor.Stop("InitRoot");
auto timestamp = qexpand->size();
auto num_leaves = 1;
while (!qexpand->empty()) {
ExpandEntry candidate = qexpand->top();
qexpand->pop();
if (!candidate.IsValid(param, num_leaves)) {
continue;
}
this->ApplySplit(candidate, p_tree);
num_leaves++;
int left_child_nidx = tree[candidate.nid].LeftChild();
int right_child_nidx = tree[candidate.nid].RightChild();
// Only create child entries if needed
if (ExpandEntry::ChildIsValid(param, tree.GetDepth(left_child_nidx),
num_leaves)) {
monitor.Start("UpdatePosition");
this->UpdatePosition(candidate.nid, (*p_tree)[candidate.nid]);
monitor.Stop("UpdatePosition");
monitor.Start("BuildHist");
this->BuildHistLeftRight(candidate, left_child_nidx, right_child_nidx, reducer);
monitor.Stop("BuildHist");
monitor.Start("EvaluateSplits");
auto splits = this->EvaluateLeftRightSplits(candidate, left_child_nidx,
right_child_nidx,
*p_tree);
monitor.Stop("EvaluateSplits");
qexpand->push(ExpandEntry(left_child_nidx,
tree.GetDepth(left_child_nidx), splits.at(0),
timestamp++));
qexpand->push(ExpandEntry(right_child_nidx,
tree.GetDepth(right_child_nidx),
splits.at(1), timestamp++));
// The set of leaves that can be expanded asynchronously
auto expand_set = driver.Pop();
while (!expand_set.empty()) {
auto new_candidates =
pinned.GetSpan<ExpandEntry>(expand_set.size() * 2, ExpandEntry());
for (auto i = 0ull; i < expand_set.size(); i++) {
auto candidate = expand_set.at(i);
if (!candidate.IsValid(param, num_leaves)) {
continue;
}
this->ApplySplit(candidate, p_tree);
num_leaves++;
int left_child_nidx = tree[candidate.nid].LeftChild();
int right_child_nidx = tree[candidate.nid].RightChild();
// Only create child entries if needed
if (ExpandEntry::ChildIsValid(param, tree.GetDepth(left_child_nidx),
num_leaves)) {
monitor.Start("UpdatePosition");
this->UpdatePosition(candidate.nid, (*p_tree)[candidate.nid]);
monitor.Stop("UpdatePosition");
monitor.Start("BuildHist");
this->BuildHistLeftRight(candidate, left_child_nidx, right_child_nidx, reducer);
monitor.Stop("BuildHist");
monitor.Start("EvaluateSplits");
this->EvaluateLeftRightSplits(candidate, left_child_nidx,
right_child_nidx, *p_tree,
new_candidates.subspan(i * 2, 2));
monitor.Stop("EvaluateSplits");
} else {
// Set default
new_candidates[i * 2] = ExpandEntry();
new_candidates[i * 2 + 1] = ExpandEntry();
}
}
dh::safe_cuda(cudaDeviceSynchronize());
driver.Push(new_candidates.begin(), new_candidates.end());
expand_set = driver.Pop();
}
monitor.Start("FinalisePosition");