Reduce device synchronisation (#5631)

* Reduce device synchronisation

* Initialise pinned memory
Rory Mitchell 2020-05-07 21:19:46 +12:00 committed by GitHub
parent 9910265064
commit fcf57823b6
7 changed files with 260 additions and 118 deletions


@@ -503,6 +503,15 @@ struct PinnedMemory {
return xgboost::common::Span<T>(static_cast<T *>(temp_storage), size);
}
template <typename T>
xgboost::common::Span<T> GetSpan(size_t size, T init) {
auto result = this->GetSpan<T>(size);
for (auto &e : result) {
e = init;
}
return result;
}
void Free() {
if (temp_storage != nullptr) {
safe_cuda(cudaFreeHost(temp_storage));
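
The "Initialise pinned memory" bullet in the commit message refers to this new GetSpan(size, init) overload: the span lives in page-locked host memory, so it can be filled directly by the CPU with no kernel launch and no device synchronisation. A standalone sketch of that idea against the plain CUDA runtime follows; the names are illustrative and not part of the diff.

#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>

int main() {
  int64_t* host_buf = nullptr;
  // Page-locked (pinned) host allocation: the CPU can write it directly, and
  // later cudaMemcpyAsync calls into it are truly asynchronous.
  if (cudaMallocHost(reinterpret_cast<void**>(&host_buf),
                     4 * sizeof(int64_t)) != cudaSuccess) {
    return 1;
  }
  for (int i = 0; i < 4; ++i) {
    host_buf[i] = 0;  // plain CPU stores: no kernel launch, no device sync
  }
  std::printf("host_buf[0] = %lld\n", static_cast<long long>(host_buf[0]));
  cudaFreeHost(host_buf);
  return 0;
}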


@@ -0,0 +1,120 @@
/*!
* Copyright 2020 by XGBoost Contributors
*/
#ifndef DRIVER_CUH_
#define DRIVER_CUH_
#include <xgboost/span.h>
#include <queue>
#include "../param.h"
#include "evaluate_splits.cuh"
namespace xgboost {
namespace tree {
struct ExpandEntry {
int nid;
int depth;
DeviceSplitCandidate split;
ExpandEntry() = default;
XGBOOST_DEVICE ExpandEntry(int nid, int depth, DeviceSplitCandidate split)
: nid(nid), depth(depth), split(std::move(split)) {}
bool IsValid(const TrainParam& param, int num_leaves) const {
if (split.loss_chg <= kRtEps) return false;
if (split.left_sum.GetHess() == 0 || split.right_sum.GetHess() == 0) {
return false;
}
if (split.loss_chg < param.min_split_loss) {
return false;
}
if (param.max_depth > 0 && depth == param.max_depth) {
return false;
}
if (param.max_leaves > 0 && num_leaves == param.max_leaves) {
return false;
}
return true;
}
static bool ChildIsValid(const TrainParam& param, int depth, int num_leaves) {
if (param.max_depth > 0 && depth >= param.max_depth) return false;
if (param.max_leaves > 0 && num_leaves >= param.max_leaves) return false;
return true;
}
friend std::ostream& operator<<(std::ostream& os, const ExpandEntry& e) {
os << "ExpandEntry: \n";
os << "nidx: " << e.nid << "\n";
os << "depth: " << e.depth << "\n";
os << "loss: " << e.split.loss_chg << "\n";
os << "left_sum: " << e.split.left_sum << "\n";
os << "right_sum: " << e.split.right_sum << "\n";
return os;
}
};
inline bool DepthWise(const ExpandEntry& lhs, const ExpandEntry& rhs) {
return lhs.depth > rhs.depth; // favor small depth
}
inline bool LossGuide(const ExpandEntry& lhs, const ExpandEntry& rhs) {
if (lhs.split.loss_chg == rhs.split.loss_chg) {
return lhs.nid > rhs.nid; // favor small nid (created earlier)
} else {
return lhs.split.loss_chg < rhs.split.loss_chg; // favor large loss_chg
}
}
// Drives execution of tree building on device
class Driver {
using ExpandQueue =
std::priority_queue<ExpandEntry, std::vector<ExpandEntry>,
std::function<bool(ExpandEntry, ExpandEntry)>>;
public:
explicit Driver(TrainParam::TreeGrowPolicy policy)
: policy_(policy),
queue_(policy == TrainParam::kDepthWise ? DepthWise : LossGuide) {}
template <typename EntryIterT>
void Push(EntryIterT begin, EntryIterT end) {
for (auto it = begin; it != end; ++it) {
const ExpandEntry& e = *it;
if (e.split.loss_chg > kRtEps) {
queue_.push(e);
}
}
}
void Push(const std::vector<ExpandEntry> &entries) {
this->Push(entries.begin(), entries.end());
}
// Return the set of nodes to be expanded
// This set has no dependencies between entries so they may be expanded in
// parallel or asynchronously
std::vector<ExpandEntry> Pop() {
if (queue_.empty()) return {};
// Return a single entry for loss guided mode
if (policy_ == TrainParam::kLossGuide) {
ExpandEntry e = queue_.top();
queue_.pop();
return {e};
}
// Return nodes on same level for depth wise
std::vector<ExpandEntry> result;
ExpandEntry e = queue_.top();
int level = e.depth;
while (e.depth == level && !queue_.empty()) {
queue_.pop();
result.emplace_back(e);
if (!queue_.empty()) {
e = queue_.top();
}
}
return result;
}
private:
TrainParam::TreeGrowPolicy policy_;
ExpandQueue queue_;
};
} // namespace tree
} // namespace xgboost
#endif // DRIVER_CUH_
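
A minimal usage sketch of the new Driver, assuming compilation with nvcc and an include path relative to the repository root (the unit test further down includes the same header via a relative path); the loop body is a placeholder rather than the real tree-building code.

#include <vector>
#include "src/tree/gpu_hist/driver.cuh"

namespace xgboost {
namespace tree {
inline void DriveExample() {
  Driver driver(TrainParam::kDepthWise);
  DeviceSplitCandidate split;
  split.loss_chg = 1.0f;
  driver.Push({ExpandEntry{0, 0, split}});
  // Each Pop() returns a batch of independent frontier nodes (a whole level
  // under the depth-wise policy); expand them, then push their children.
  for (auto batch = driver.Pop(); !batch.empty(); batch = driver.Pop()) {
    std::vector<ExpandEntry> children;
    for (const auto& e : batch) {
      // ... build histograms / evaluate splits for e, append its children ...
      (void)e;
    }
    driver.Push(children);
  }
}
}  // namespace tree
}  // namespace xgboost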


@@ -61,6 +61,7 @@ class RowPartitioner {
dh::caching_device_vector<int64_t>
left_counts_; // Useful to keep a bunch of zeroed memory for sort position
std::vector<cudaStream_t> streams_;
dh::PinnedMemory pinned_;
public:
RowPartitioner(int device_idx, size_t num_rows);
@@ -129,12 +130,12 @@ class RowPartitioner {
d_position[idx] = new_position;
});
// Overlap device to host memory copy (left_count) with sort
int64_t left_count;
int64_t &left_count = pinned_.GetSpan<int64_t>(1)[0];
dh::safe_cuda(cudaMemcpyAsync(&left_count, d_left_count, sizeof(int64_t),
cudaMemcpyDeviceToHost, streams_[0]));
SortPositionAndCopy(segment, left_nidx, right_nidx, d_left_count,
streams_[1]);
SortPositionAndCopy(segment, left_nidx, right_nidx, d_left_count, streams_[1]);
dh::safe_cuda(cudaStreamSynchronize(streams_[0]));
CHECK_LE(left_count, segment.Size());
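
Keeping left_count in pinned memory is what makes the overlap above possible: the device-to-host transfer is issued asynchronously on streams_[0] while SortPositionAndCopy runs on streams_[1], and only the copy stream is synchronised. A standalone sketch of that pattern against the plain CUDA runtime follows; the kernel and names are illustrative, not part of the diff.

#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>

// Stand-in for the independent work (the sort in the real code).
__global__ void BusyKernel(float* out, int n) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += gridDim.x * blockDim.x) {
    out[i] = sqrtf(static_cast<float>(i));
  }
}

int main() {
  cudaStream_t copy_stream, work_stream;
  cudaStreamCreate(&copy_stream);
  cudaStreamCreate(&work_stream);

  int64_t* d_count = nullptr;
  int64_t* h_count = nullptr;
  float* d_work = nullptr;
  int const n = 1 << 20;
  cudaMalloc(&d_count, sizeof(int64_t));
  cudaMalloc(&d_work, n * sizeof(float));
  cudaMemset(d_count, 0, sizeof(int64_t));
  // Pinned destination: required for cudaMemcpyAsync to be truly asynchronous.
  cudaMallocHost(reinterpret_cast<void**>(&h_count), sizeof(int64_t));

  // The copy and the kernel touch different data, so they can overlap.
  cudaMemcpyAsync(h_count, d_count, sizeof(int64_t), cudaMemcpyDeviceToHost,
                  copy_stream);
  BusyKernel<<<128, 256, 0, work_stream>>>(d_work, n);

  // Synchronise only the copy stream; the host can read h_count while
  // work_stream is still busy.
  cudaStreamSynchronize(copy_stream);
  std::printf("count = %lld\n", static_cast<long long>(*h_count));

  cudaDeviceSynchronize();
  cudaFreeHost(h_count);
  cudaFree(d_count);
  cudaFree(d_work);
  cudaStreamDestroy(copy_stream);
  cudaStreamDestroy(work_stream);
  return 0;
}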


@@ -30,6 +30,7 @@
#include "gpu_hist/row_partitioner.cuh"
#include "gpu_hist/histogram.cuh"
#include "gpu_hist/evaluate_splits.cuh"
#include "gpu_hist/driver.cuh"
namespace xgboost {
namespace tree {
@@ -57,58 +58,6 @@ struct GPUHistMakerTrainParam
DMLC_REGISTER_PARAMETER(GPUHistMakerTrainParam);
#endif // !defined(GTEST_TEST)
struct ExpandEntry {
int nid;
int depth;
DeviceSplitCandidate split;
uint64_t timestamp;
ExpandEntry() = default;
ExpandEntry(int nid, int depth, DeviceSplitCandidate split,
uint64_t timestamp)
: nid(nid), depth(depth), split(std::move(split)), timestamp(timestamp) {}
bool IsValid(const TrainParam& param, int num_leaves) const {
if (split.loss_chg <= kRtEps) return false;
if (split.left_sum.GetHess() == 0 || split.right_sum.GetHess() == 0) {
return false;
}
if (split.loss_chg < param.min_split_loss) { return false; }
if (param.max_depth > 0 && depth == param.max_depth) {return false; }
if (param.max_leaves > 0 && num_leaves == param.max_leaves) { return false; }
return true;
}
static bool ChildIsValid(const TrainParam& param, int depth, int num_leaves) {
if (param.max_depth > 0 && depth >= param.max_depth) return false;
if (param.max_leaves > 0 && num_leaves >= param.max_leaves) return false;
return true;
}
friend std::ostream& operator<<(std::ostream& os, const ExpandEntry& e) {
os << "ExpandEntry: \n";
os << "nidx: " << e.nid << "\n";
os << "depth: " << e.depth << "\n";
os << "loss: " << e.split.loss_chg << "\n";
os << "left_sum: " << e.split.left_sum << "\n";
os << "right_sum: " << e.split.right_sum << "\n";
return os;
}
};
inline static bool DepthWise(const ExpandEntry& lhs, const ExpandEntry& rhs) {
if (lhs.depth == rhs.depth) {
return lhs.timestamp > rhs.timestamp; // favor small timestamp
} else {
return lhs.depth > rhs.depth; // favor small depth
}
}
inline static bool LossGuide(const ExpandEntry& lhs, const ExpandEntry& rhs) {
if (lhs.split.loss_chg == rhs.split.loss_chg) {
return lhs.timestamp > rhs.timestamp; // favor small timestamp
} else {
return lhs.split.loss_chg < rhs.split.loss_chg; // favor large loss_chg
}
}
/**
* \struct DeviceHistogram
*
@@ -243,6 +192,8 @@ struct GPUHistMakerDevice {
GradientSumT histogram_rounding;
dh::PinnedMemory pinned;
std::vector<cudaStream_t> streams{};
common::Monitor monitor;
@@ -250,11 +201,6 @@ struct GPUHistMakerDevice {
common::ColumnSampler column_sampler;
FeatureInteractionConstraintDevice interaction_constraints;
using ExpandQueue =
std::priority_queue<ExpandEntry, std::vector<ExpandEntry>,
std::function<bool(ExpandEntry, ExpandEntry)>>;
std::unique_ptr<ExpandQueue> qexpand;
std::unique_ptr<GradientBasedSampler> sampler;
GPUHistMakerDevice(int _device_id,
@@ -314,11 +260,6 @@ struct GPUHistMakerDevice {
// Note that the column sampler must be passed by value because it is not
// thread safe
void Reset(HostDeviceVector<GradientPair>* dh_gpair, DMatrix* dmat, int64_t num_columns) {
if (param.grow_policy == TrainParam::kLossGuide) {
qexpand.reset(new ExpandQueue(LossGuide));
} else {
qexpand.reset(new ExpandQueue(DepthWise));
}
this->column_sampler.Init(num_columns, param.colsample_bynode,
param.colsample_bylevel, param.colsample_bytree);
dh::safe_cuda(cudaSetDevice(device_id));
@@ -370,9 +311,9 @@ struct GPUHistMakerDevice {
return result.front();
}
std::vector<DeviceSplitCandidate> EvaluateLeftRightSplits(
ExpandEntry candidate, int left_nidx, int right_nidx,
const RegTree& tree) {
void EvaluateLeftRightSplits(
ExpandEntry candidate, int left_nidx, int right_nidx, const RegTree& tree,
common::Span<ExpandEntry> pinned_candidates_out) {
dh::TemporaryArray<DeviceSplitCandidate> splits_out(2);
GPUTrainingParam gpu_param(param);
auto left_sampled_features =
@@ -412,12 +353,19 @@ struct GPUHistMakerDevice {
hist.GetNodeHistogram(right_nidx),
node_value_constraints[right_nidx],
dh::ToSpan(monotone_constraints)};
EvaluateSplits(dh::ToSpan(splits_out), left, right);
std::vector<DeviceSplitCandidate> result(2);
dh::safe_cuda(cudaMemcpy(result.data(), splits_out.data().get(),
sizeof(DeviceSplitCandidate) * splits_out.size(),
cudaMemcpyDeviceToHost));
return result;
auto d_splits_out = dh::ToSpan(splits_out);
EvaluateSplits(d_splits_out, left, right);
dh::TemporaryArray<ExpandEntry> entries(2);
auto d_entries = entries.data().get();
dh::LaunchN(device_id, 1, [=] __device__(size_t idx) {
d_entries[0] =
ExpandEntry(left_nidx, candidate.depth + 1, d_splits_out[0]);
d_entries[1] =
ExpandEntry(right_nidx, candidate.depth + 1, d_splits_out[1]);
});
dh::safe_cuda(cudaMemcpyAsync(
pinned_candidates_out.data(), entries.data().get(),
sizeof(ExpandEntry) * entries.size(), cudaMemcpyDeviceToHost));
}
void BuildHist(int nidx) {
@@ -637,7 +585,7 @@ struct GPUHistMakerDevice {
tree[candidate.nid].RightChild());
}
void InitRoot(RegTree* p_tree, dh::AllReducer* reducer) {
ExpandEntry InitRoot(RegTree* p_tree, dh::AllReducer* reducer) {
constexpr bst_node_t kRootNIdx = 0;
dh::XGBCachingDeviceAllocator<char> alloc;
GradientPair root_sum = thrust::reduce(
@@ -662,61 +610,66 @@ struct GPUHistMakerDevice {
// Generate first split
auto split = this->EvaluateRootSplit(root_sum);
qexpand->push(
ExpandEntry(kRootNIdx, p_tree->GetDepth(kRootNIdx), split, 0));
return ExpandEntry(kRootNIdx, p_tree->GetDepth(kRootNIdx), split);
}
void UpdateTree(HostDeviceVector<GradientPair>* gpair_all, DMatrix* p_fmat,
RegTree* p_tree, dh::AllReducer* reducer) {
auto& tree = *p_tree;
Driver driver(static_cast<TrainParam::TreeGrowPolicy>(param.grow_policy));
monitor.Start("Reset");
this->Reset(gpair_all, p_fmat, p_fmat->Info().num_col_);
monitor.Stop("Reset");
monitor.Start("InitRoot");
this->InitRoot(p_tree, reducer);
driver.Push({ this->InitRoot(p_tree, reducer) });
monitor.Stop("InitRoot");
auto timestamp = qexpand->size();
auto num_leaves = 1;
while (!qexpand->empty()) {
ExpandEntry candidate = qexpand->top();
qexpand->pop();
if (!candidate.IsValid(param, num_leaves)) {
continue;
}
this->ApplySplit(candidate, p_tree);
num_leaves++;
int left_child_nidx = tree[candidate.nid].LeftChild();
int right_child_nidx = tree[candidate.nid].RightChild();
// Only create child entries if needed
if (ExpandEntry::ChildIsValid(param, tree.GetDepth(left_child_nidx),
num_leaves)) {
monitor.Start("UpdatePosition");
this->UpdatePosition(candidate.nid, (*p_tree)[candidate.nid]);
monitor.Stop("UpdatePosition");
monitor.Start("BuildHist");
this->BuildHistLeftRight(candidate, left_child_nidx, right_child_nidx, reducer);
monitor.Stop("BuildHist");
monitor.Start("EvaluateSplits");
auto splits = this->EvaluateLeftRightSplits(candidate, left_child_nidx,
right_child_nidx,
*p_tree);
monitor.Stop("EvaluateSplits");
qexpand->push(ExpandEntry(left_child_nidx,
tree.GetDepth(left_child_nidx), splits.at(0),
timestamp++));
qexpand->push(ExpandEntry(right_child_nidx,
tree.GetDepth(right_child_nidx),
splits.at(1), timestamp++));
// The set of leaves that can be expanded asynchronously
auto expand_set = driver.Pop();
while (!expand_set.empty()) {
auto new_candidates =
pinned.GetSpan<ExpandEntry>(expand_set.size() * 2, ExpandEntry());
for (auto i = 0ull; i < expand_set.size(); i++) {
auto candidate = expand_set.at(i);
if (!candidate.IsValid(param, num_leaves)) {
continue;
}
this->ApplySplit(candidate, p_tree);
num_leaves++;
int left_child_nidx = tree[candidate.nid].LeftChild();
int right_child_nidx = tree[candidate.nid].RightChild();
// Only create child entries if needed
if (ExpandEntry::ChildIsValid(param, tree.GetDepth(left_child_nidx),
num_leaves)) {
monitor.Start("UpdatePosition");
this->UpdatePosition(candidate.nid, (*p_tree)[candidate.nid]);
monitor.Stop("UpdatePosition");
monitor.Start("BuildHist");
this->BuildHistLeftRight(candidate, left_child_nidx, right_child_nidx, reducer);
monitor.Stop("BuildHist");
monitor.Start("EvaluateSplits");
this->EvaluateLeftRightSplits(candidate, left_child_nidx,
right_child_nidx, *p_tree,
new_candidates.subspan(i * 2, 2));
monitor.Stop("EvaluateSplits");
} else {
// Set default (invalid) entries for skipped candidates; Driver::Push
// filters out entries with loss_chg <= kRtEps, so these are never queued
new_candidates[i * 2] = ExpandEntry();
new_candidates[i * 2 + 1] = ExpandEntry();
}
}
dh::safe_cuda(cudaDeviceSynchronize());
driver.Push(new_candidates.begin(), new_candidates.end());
expand_set = driver.Pop();
}
monitor.Start("FinalisePosition");


@@ -264,7 +264,7 @@ TEST_F(SerializationTest, CPUCoordDescent) {
}
#if defined(XGBOOST_USE_CUDA)
TEST_F(SerializationTest, GPUHist) {
TEST_F(SerializationTest, GpuHist) {
TestLearnerSerialization({{"booster", "gbtree"},
{"seed", "0"},
{"enable_experimental_json_serialization", "1"},
@@ -441,7 +441,7 @@ TEST_F(LogitSerializationTest, CPUCoordDescent) {
}
#if defined(XGBOOST_USE_CUDA)
TEST_F(LogitSerializationTest, GPUHist) {
TEST_F(LogitSerializationTest, GpuHist) {
TestLearnerSerialization({{"booster", "gbtree"},
{"objective", "binary:logistic"},
{"seed", "0"},
@@ -596,7 +596,7 @@ TEST_F(MultiClassesSerializationTest, CPUCoordDescent) {
}
#if defined(XGBOOST_USE_CUDA)
TEST_F(MultiClassesSerializationTest, GPUHist) {
TEST_F(MultiClassesSerializationTest, GpuHist) {
TestLearnerSerialization({{"booster", "gbtree"},
{"num_class", std::to_string(kClasses)},
{"seed", "0"},


@@ -0,0 +1,59 @@
#include <gtest/gtest.h>
#include "../../../../src/tree/gpu_hist/driver.cuh"
namespace xgboost {
namespace tree {
TEST(GpuHist, DriverDepthWise) {
Driver driver(TrainParam::kDepthWise);
EXPECT_TRUE(driver.Pop().empty());
DeviceSplitCandidate split;
split.loss_chg = 1.0f;
ExpandEntry root(0, 0, split);
driver.Push({root});
EXPECT_EQ(driver.Pop().front().nid, 0);
driver.Push({ExpandEntry{1, 1, split}});
driver.Push({ExpandEntry{2, 1, split}});
driver.Push({ExpandEntry{3, 2, split}});
// Should return entries from level 1
auto res = driver.Pop();
EXPECT_EQ(res.size(), 2);
for (auto &e : res) {
EXPECT_EQ(e.depth, 1);
}
res = driver.Pop();
EXPECT_EQ(res[0].depth, 2);
EXPECT_TRUE(driver.Pop().empty());
}
TEST(GpuHist, DriverLossGuided) {
DeviceSplitCandidate high_gain;
high_gain.loss_chg = 5.0f;
DeviceSplitCandidate low_gain;
low_gain.loss_chg = 1.0f;
Driver driver(TrainParam::kLossGuide);
EXPECT_TRUE(driver.Pop().empty());
ExpandEntry root(0, 0, high_gain);
driver.Push({root});
EXPECT_EQ(driver.Pop().front().nid, 0);
// Select high gain first
driver.Push({ExpandEntry{1, 1, low_gain}});
driver.Push({ExpandEntry{2, 2, high_gain}});
auto res = driver.Pop();
EXPECT_EQ(res.size(), 1);
EXPECT_EQ(res[0].nid, 2);
res = driver.Pop();
EXPECT_EQ(res.size(), 1);
EXPECT_EQ(res[0].nid, 1);
// If equal gain, use nid
driver.Push({ExpandEntry{2, 1, low_gain}});
driver.Push({ExpandEntry{1, 1, low_gain}});
res = driver.Pop();
EXPECT_EQ(res[0].nid, 1);
res = driver.Pop();
EXPECT_EQ(res[0].nid, 2);
}
} // namespace tree
} // namespace xgboost


@@ -40,7 +40,7 @@ class UpdaterTreeStatTest : public ::testing::Test {
};
#if defined(XGBOOST_USE_CUDA)
TEST_F(UpdaterTreeStatTest, GPUHist) {
TEST_F(UpdaterTreeStatTest, GpuHist) {
this->RunTest("grow_gpu_hist");
}
#endif // defined(XGBOOST_USE_CUDA)