Remove various synchronisations from cuda API calls, instrument monitor (#4205)

* Remove various synchronisations from cuda API calls, instrument monitor
with nvtx profiler ranges.
Rory Mitchell 2019-03-10 15:01:23 +13:00 committed by GitHub
parent f83e62dca5
commit 4eeeded7d1
9 changed files with 116 additions and 104 deletions

@@ -10,9 +10,12 @@ msvc_use_static_runtime()
# Options
## GPUs
option(USE_CUDA "Build with GPU acceleration" OFF)
option(USE_NVTX "Build with cuda profiling annotations. Developers only." OFF)
option(USE_NCCL "Build with multiple GPUs support" OFF)
set(GPU_COMPUTE_VER "" CACHE STRING
"Space separated list of compute versions to be built against, e.g. '35 61'")
set(NVTX_HEADER_DIR "" CACHE PATH
"Path to the stand-alone nvtx header")
## Bindings
option(JVM_BINDINGS "Build JVM bindings" OFF)
@@ -175,6 +178,11 @@ if(USE_CUDA AND (NOT GENERATE_COMPILATION_DATABASE))
add_definitions(-DXGBOOST_USE_NCCL)
endif()
if(USE_NVTX)
cuda_include_directories("${NVTX_HEADER_DIR}")
add_definitions(-DXGBOOST_USE_NVTX)
endif()
set(GENCODE_FLAGS "")
format_gencode_flags("${GPU_COMPUTE_VER}" GENCODE_FLAGS)
message("cuda architecture flags: ${GENCODE_FLAGS}")
@@ -190,6 +198,7 @@ if(USE_CUDA AND (NOT GENERATE_COMPILATION_DATABASE))
link_directories(${NCCL_LIBRARY})
target_link_libraries(gpuxgboost ${NCCL_LIB_NAME})
endif()
list(APPEND LINK_LIBRARIES gpuxgboost)
elseif (USE_CUDA AND GENERATE_COMPILATION_DATABASE)

@@ -195,6 +195,10 @@ Training time on 1,000,000 rows x 50 columns with 500 boosting iterations a
See `GPU Accelerated XGBoost <https://xgboost.ai/2016/12/14/GPU-accelerated-xgboost.html>`_ and `Updates to the XGBoost GPU algorithms <https://xgboost.ai/2018/07/04/gpu-xgboost-update.html>`_ for additional performance benchmarks of the ``gpu_exact`` and ``gpu_hist`` tree methods.
Developer notes
===============
The application may be profiled with annotations by specifying USE_NVTX to CMake and providing the path to the stand-alone NVTX header via NVTX_HEADER_DIR. Regions covered by the 'Monitor' class in CUDA code will automatically appear in the Nsight profiler.
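Configuring with ``-DUSE_NVTX=ON`` and ``-DNVTX_HEADER_DIR=<path>`` turns the annotations on. A minimal sketch of the guarded range pattern this enables follows; the function and range names are illustrative only, not code from this repository::

    #if defined(XGBOOST_USE_NVTX) && defined(__CUDACC__)
    #include <nvToolsExt.h>
    #endif

    void ProfiledRegion() {
    #if defined(XGBOOST_USE_NVTX) && defined(__CUDACC__)
      nvtxRangeId_t range_id = nvtxRangeStartA("ProfiledRegion");  // shows up on the profiler timeline
    #endif
      // ... CUDA work to be profiled ...
    #if defined(XGBOOST_USE_NVTX) && defined(__CUDACC__)
      nvtxRangeEnd(range_id);
    #endif
    }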
**********
References
**********

@@ -308,7 +308,7 @@ class DVec {
}
safe_cuda(cudaSetDevice(this->DeviceIdx()));
if (other.DeviceIdx() == this->DeviceIdx()) {
dh::safe_cuda(cudaMemcpy(this->Data(), other.Data(),
dh::safe_cuda(cudaMemcpyAsync(this->Data(), other.Data(),
other.Size() * sizeof(T),
cudaMemcpyDeviceToDevice));
} else {
@@ -338,7 +338,7 @@ class DVec {
throw std::runtime_error(
"Cannot copy assign vector to dvec, sizes are different");
}
safe_cuda(cudaMemcpy(this->Data(), begin.get(), Size() * sizeof(T),
safe_cuda(cudaMemcpyAsync(this->Data(), begin.get(), Size() * sizeof(T),
cudaMemcpyDefault));
}
};
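A minimal sketch (not part of this patch) of why the cudaMemcpy to cudaMemcpyAsync change above is safe: copies issued on the default stream stay ordered with respect to kernels on that stream, so the host only needs to synchronise when it actually reads the result.

```cpp
#include <cstddef>
#include <cuda_runtime.h>

// Sketch only: device-to-device copy queued on the default stream.
void CopyOnDefaultStream(float* dst, const float* src, size_t n) {
  cudaMemcpyAsync(dst, src, n * sizeof(float), cudaMemcpyDeviceToDevice);
  // Kernels launched afterwards on the same stream see the copied data without
  // an explicit cudaDeviceSynchronize() in between.
  cudaDeviceSynchronize();  // needed only once the host must observe the data
}
```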

@@ -290,14 +290,14 @@ struct GPUSketcher {
offset_vec[row_begin_ + batch_row_begin];
// copy the batch to the GPU
dh::safe_cuda
(cudaMemcpy(entries_.data().get(),
(cudaMemcpyAsync(entries_.data().get(),
data_vec.data() + offset_vec[row_begin_ + batch_row_begin],
n_entries * sizeof(Entry), cudaMemcpyDefault));
// copy the weights if necessary
if (has_weights_) {
const auto& weights_vec = info.weights_.HostVector();
dh::safe_cuda
(cudaMemcpy(weights_.data().get(),
(cudaMemcpyAsync(weights_.data().get(),
weights_vec.data() + row_begin_ + batch_row_begin,
batch_nrows * sizeof(bst_float), cudaMemcpyDefault));
}
@@ -315,15 +315,11 @@ struct GPUSketcher {
has_weights_ ? weights_.data().get() : nullptr, entries_.data().get(),
gpu_batch_nrows_, num_cols_,
offset_vec[row_begin_ + batch_row_begin], batch_nrows);
dh::safe_cuda(cudaGetLastError()); // NOLINT
dh::safe_cuda(cudaDeviceSynchronize()); // NOLINT
for (int icol = 0; icol < num_cols_; ++icol) {
FindColumnCuts(batch_nrows, icol);
}
dh::safe_cuda(cudaDeviceSynchronize()); // NOLINT
// add cuts into sketches
thrust::copy(cuts_d_.begin(), cuts_d_.end(), cuts_h_.begin());
for (int icol = 0; icol < num_cols_; ++icol) {

@@ -74,14 +74,14 @@ struct HostDeviceVectorImpl {
// TODO(canonizer): avoid full copy of host data
LazySyncDevice(GPUAccess::kWrite);
SetDevice();
dh::safe_cuda(cudaMemcpy(data_.data().get(), begin + start_,
dh::safe_cuda(cudaMemcpyAsync(data_.data().get(), begin + start_,
data_.size() * sizeof(T), cudaMemcpyDefault));
}
void GatherTo(thrust::device_ptr<T> begin) {
LazySyncDevice(GPUAccess::kRead);
SetDevice();
dh::safe_cuda(cudaMemcpy(begin.get() + start_, data_.data().get(),
dh::safe_cuda(cudaMemcpyAsync(begin.get() + start_, data_.data().get(),
proper_size_ * sizeof(T), cudaMemcpyDefault));
}
@@ -97,7 +97,7 @@ struct HostDeviceVectorImpl {
LazySyncDevice(GPUAccess::kWrite);
other->LazySyncDevice(GPUAccess::kRead);
SetDevice();
dh::safe_cuda(cudaMemcpy(data_.data().get(), other->data_.data().get(),
dh::safe_cuda(cudaMemcpyAsync(data_.data().get(), other->data_.data().get(),
data_.size() * sizeof(T), cudaMemcpyDefault));
}
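One caveat worth noting, not something the patch itself has to address: cudaMemcpyAsync involving pageable host memory is staged by the driver and behaves largely like a synchronous copy from the host's point of view. Real overlap requires pinned host buffers, roughly as sketched below (illustrative function, not repository code).

```cpp
#include <cstddef>
#include <cuda_runtime.h>

// Illustrative sketch: pinned (page-locked) host memory lets the async copy
// overlap with host work; the stream sync guards the buffer before reuse.
void AsyncHostToDevice(float* d_dst, size_t n) {
  float* h_src = nullptr;
  cudaMallocHost(&h_src, n * sizeof(float));  // pinned allocation
  // ... fill h_src on the host ...
  cudaMemcpyAsync(d_dst, h_src, n * sizeof(float), cudaMemcpyHostToDevice);
  cudaStreamSynchronize(0);                   // wait before freeing or reusing h_src
  cudaFreeHost(h_src);
}
```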

@@ -8,7 +8,9 @@
#include <map>
#include <string>
#include "common.h"
#if defined(XGBOOST_USE_NVTX) && defined(__CUDACC__)
#include <nvToolsExt.h>
#endif
namespace xgboost {
namespace common {
@@ -45,9 +47,11 @@ struct Timer {
*/
struct Monitor {
private:
struct Statistics {
Timer timer;
size_t count{0};
uint64_t nvtx_id;
};
std::string label = "";
std::map<std::string, Statistics> statistics_map;
@@ -75,35 +79,37 @@ struct Monitor {
}
self_timer.Stop();
}
void Init(std::string label) {
this->label = label;
}
void Start(const std::string &name) { statistics_map[name].timer.Start(); }
void Start(const std::string &name, GPUSet devices) {
void Init(std::string label) { this->label = label; }
void Start(const std::string &name) {
if (ConsoleLogger::ShouldLog(ConsoleLogger::LV::kDebug)) {
#ifdef __CUDACC__
for (auto device : devices) {
cudaSetDevice(device);
cudaDeviceSynchronize();
}
#endif // __CUDACC__
statistics_map[name].timer.Start();
}
statistics_map[name].timer.Start();
}
void Stop(const std::string &name) {
statistics_map[name].timer.Stop();
statistics_map[name].count++;
}
void Stop(const std::string &name, GPUSet devices) {
if (ConsoleLogger::ShouldLog(ConsoleLogger::LV::kDebug)) {
#ifdef __CUDACC__
for (auto device : devices) {
cudaSetDevice(device);
cudaDeviceSynchronize();
}
#endif // __CUDACC__
auto &stats = statistics_map[name];
stats.timer.Stop();
stats.count++;
}
}
void StartCuda(const std::string &name) {
if (ConsoleLogger::ShouldLog(ConsoleLogger::LV::kDebug)) {
auto &stats = statistics_map[name];
stats.timer.Start();
#if defined(XGBOOST_USE_NVTX) && defined(__CUDACC__)
stats.nvtx_id = nvtxRangeStartA(name.c_str());
#endif
}
}
void StopCuda(const std::string &name) {
if (ConsoleLogger::ShouldLog(ConsoleLogger::LV::kDebug)) {
auto &stats = statistics_map[name];
stats.timer.Stop();
stats.count++;
#if defined(XGBOOST_USE_NVTX) && defined(__CUDACC__)
nvtxRangeEnd(stats.nvtx_id);
#endif
}
this->Stop(name);
}
};
} // namespace common
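A usage sketch of the instrumented Monitor defined above; the include path and surrounding function are assumptions for illustration. Timings are reported at debug verbosity, and the NVTX ranges appear only when the build defines XGBOOST_USE_NVTX and the translation unit is compiled with nvcc.

```cpp
#include "../src/common/timer.h"  // assumed relative path to the header shown above

void TrainOneIteration() {
  xgboost::common::Monitor monitor;
  monitor.Init("MyUpdater");

  monitor.StartCuda("BuildHist");  // starts the timer and opens an NVTX range
  // ... GPU work ...
  monitor.StopCuda("BuildHist");   // stops the timer, bumps the count, closes the range

  monitor.Start("HostStep");       // plain variant without NVTX annotation
  // ... host-only work ...
  monitor.Stop("HostStep");
}
```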

@@ -145,8 +145,6 @@ class Transform {
static_cast<int>(dh::DivRoundUp(*(range_.end()), kBlockThreads));
detail::LaunchCUDAKernel<<<GRID_SIZE, kBlockThreads>>>(
_func, shard_range, UnpackHDV(_vectors, device)...);
dh::safe_cuda(cudaGetLastError());
dh::safe_cuda(cudaDeviceSynchronize());
}
}
#else
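Dropping the cudaDeviceSynchronize() after each launch removes a host stall. A hedged sketch of how launch errors can still be surfaced without a device-wide sync; the kernel and names are illustrative, not the Transform code itself.

```cuda
#include <cstdio>
#include <cuda_runtime.h>

__global__ void ScaleKernel(float* data, size_t n) {
  size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) data[i] *= 2.0f;
}

void LaunchWithoutDeviceSync(float* data, size_t n) {
  unsigned grid = static_cast<unsigned>((n + 255) / 256);
  ScaleKernel<<<grid, 256>>>(data, n);
  cudaError_t status = cudaPeekAtLastError();  // catches invalid launch configurations immediately
  if (status != cudaSuccess) {
    std::printf("launch failed: %s\n", cudaGetErrorString(status));
  }
  // Errors raised during execution surface at the next synchronising call instead.
}
```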

@@ -252,17 +252,17 @@ class GPUPredictor : public xgboost::Predictor {
size_t tree_begin, size_t tree_end) {
dh::safe_cuda(cudaSetDevice(device_));
nodes.resize(h_nodes.size());
dh::safe_cuda(cudaMemcpy(dh::Raw(nodes), h_nodes.data(),
dh::safe_cuda(cudaMemcpyAsync(dh::Raw(nodes), h_nodes.data(),
sizeof(DevicePredictionNode) * h_nodes.size(),
cudaMemcpyHostToDevice));
tree_segments.resize(h_tree_segments.size());
dh::safe_cuda(cudaMemcpy(dh::Raw(tree_segments), h_tree_segments.data(),
dh::safe_cuda(cudaMemcpyAsync(dh::Raw(tree_segments), h_tree_segments.data(),
sizeof(size_t) * h_tree_segments.size(),
cudaMemcpyHostToDevice));
tree_group.resize(model.tree_info.size());
dh::safe_cuda(cudaMemcpy(dh::Raw(tree_group), model.tree_info.data(),
dh::safe_cuda(cudaMemcpyAsync(dh::Raw(tree_group), model.tree_info.data(),
sizeof(int) * model.tree_info.size(),
cudaMemcpyHostToDevice));
@@ -288,9 +288,6 @@ class GPUPredictor : public xgboost::Predictor {
dh::ToSpan(tree_group), batch.offset.DeviceSpan(device_),
batch.data.DeviceSpan(device_), tree_begin, tree_end, info.num_col_,
num_rows, entry_start, use_shared, model.param.num_output_group);
dh::safe_cuda(cudaGetLastError());
dh::safe_cuda(cudaDeviceSynchronize());
}
int device_;
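With the explicit synchronisation after the predict kernel removed, a later blocking copy that brings results back to the host becomes the effective sync point, roughly as sketched here; the function and names are illustrative, not the predictor's actual code path.

```cpp
#include <cstddef>
#include <cuda_runtime.h>

void ReadBackPredictions(const float* d_preds, float* h_preds, size_t n) {
  // The kernel writing d_preds was launched earlier on the default stream;
  // this synchronous copy waits for it before returning.
  cudaMemcpy(h_preds, d_preds, n * sizeof(float), cudaMemcpyDeviceToHost);
}
```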

@@ -289,7 +289,9 @@ struct DeviceHistogram {
void Reset() {
dh::safe_cuda(cudaSetDevice(device_id_));
data.resize(0);
dh::safe_cuda(cudaMemsetAsync(
data.data().get(), 0,
data.size() * sizeof(typename decltype(data)::value_type)));
nidx_map.clear();
}
@@ -299,20 +301,25 @@ struct DeviceHistogram {
void AllocateHistogram(int nidx) {
if (HistogramExists(nidx)) return;
size_t current_size =
nidx_map.size() * n_bins * 2; // Number of items currently used in data
dh::safe_cuda(cudaSetDevice(device_id_));
if (data.size() > kStopGrowingSize) {
if (data.size() >= kStopGrowingSize) {
// Recycle histogram memory
std::pair<int, size_t> old_entry = *nidx_map.begin();
nidx_map.erase(old_entry.first);
dh::safe_cuda(cudaMemset(data.data().get() + old_entry.second, 0,
dh::safe_cuda(cudaMemsetAsync(data.data().get() + old_entry.second, 0,
n_bins * sizeof(GradientSumT)));
nidx_map[nidx] = old_entry.second;
} else {
// Append new node histogram
nidx_map[nidx] = data.size();
// x 2: Hess and Grad.
data.resize(data.size() + (n_bins * 2));
nidx_map[nidx] = current_size;
if (data.size() < current_size + n_bins * 2) {
size_t new_size = current_size * 2; // Double in size
new_size = std::max(static_cast<size_t>(n_bins * 2),
new_size); // Have at least one histogram
data.resize(new_size);
}
}
}
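A host-side sketch of the growth policy above, assuming a plain std::vector stands in for the device buffer: capacity doubles with a floor of one histogram (n_bins * 2 entries, gradient and hessian per bin), so per-node allocations are amortised instead of reallocating on every node.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

void GrowIfNeeded(std::vector<double>* data, size_t current_size, size_t n_bins) {
  if (data->size() < current_size + n_bins * 2) {
    size_t new_size = std::max(current_size * 2, n_bins * 2);  // double, at least one histogram
    data->resize(new_size);
  }
}
```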
@@ -610,20 +617,20 @@ struct DeviceShard {
feature_set_d.resize(feature_set.size());
auto d_features = common::Span<int>(feature_set_d.data().get(),
feature_set_d.size());
dh::safe_cuda(cudaMemcpy(d_features.data(), feature_set.data(),
dh::safe_cuda(cudaMemcpyAsync(d_features.data(), feature_set.data(),
d_features.size_bytes(), cudaMemcpyDefault));
DeviceNodeStats node(node_sum_gradients[nidx], nidx, param);
// One block for each feature
int constexpr BLOCK_THREADS = 256;
EvaluateSplitKernel<BLOCK_THREADS, GradientSumT>
<<<uint32_t(feature_set.size()), BLOCK_THREADS, 0>>>
(hist.GetNodeHistogram(nidx), d_features, node,
cut_.feature_segments.GetSpan(), cut_.min_fvalue.GetSpan(),
cut_.gidx_fvalue_map.GetSpan(), GPUTrainingParam(param),
d_split_candidates, value_constraint, monotone_constraints.GetSpan());
<<<uint32_t(feature_set.size()), BLOCK_THREADS, 0>>>(
hist.GetNodeHistogram(nidx), d_features, node,
cut_.feature_segments.GetSpan(), cut_.min_fvalue.GetSpan(),
cut_.gidx_fvalue_map.GetSpan(), GPUTrainingParam(param),
d_split_candidates, value_constraint,
monotone_constraints.GetSpan());
dh::safe_cuda(cudaDeviceSynchronize());
std::vector<DeviceSplitCandidate> split_candidates(feature_set.size());
dh::safe_cuda(cudaMemcpy(split_candidates.data(), d_split_candidates.data(),
split_candidates.size() * sizeof(DeviceSplitCandidate),
@@ -725,20 +732,21 @@ struct DeviceShard {
common::Span<bst_uint>(ridx.Current() + segment.begin, segment.Size()),
common::Span<bst_uint>(ridx.other() + segment.begin, segment.Size()),
left_nidx, right_nidx, left_count);
// Copy back key
dh::safe_cuda(cudaMemcpy(
position.Current() + segment.begin, position.other() + segment.begin,
segment.Size() * sizeof(int), cudaMemcpyDeviceToDevice));
// Copy back value
dh::safe_cuda(cudaMemcpy(
ridx.Current() + segment.begin, ridx.other() + segment.begin,
segment.Size() * sizeof(bst_uint), cudaMemcpyDeviceToDevice));
// Copy back key/value
const auto d_position_current = position.Current() + segment.begin;
const auto d_position_other = position.other() + segment.begin;
const auto d_ridx_current = ridx.Current() + segment.begin;
const auto d_ridx_other = ridx.other() + segment.begin;
dh::LaunchN(device_id_, segment.Size(), [=] __device__(size_t idx) {
d_position_current[idx] = d_position_other[idx];
d_ridx_current[idx] = d_ridx_other[idx];
});
}
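The two device-to-device copies above are folded into one element-wise kernel via dh::LaunchN. A rough equivalent with a hand-written kernel, purely for illustration:

```cuda
#include <cstddef>

__global__ void FusedCopyBack(int* pos_dst, const int* pos_src,
                              unsigned* ridx_dst, const unsigned* ridx_src,
                              size_t n) {
  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= n) return;
  pos_dst[idx] = pos_src[idx];    // key (node position)
  ridx_dst[idx] = ridx_src[idx];  // value (row index)
}
// A single launch replaces two cudaMemcpyDeviceToDevice calls, saving API
// overhead and keeping the work asynchronous on the stream.
```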
void UpdatePredictionCache(bst_float* out_preds_d) {
dh::safe_cuda(cudaSetDevice(device_id_));
if (!prediction_cache_initialised) {
dh::safe_cuda(cudaMemcpy(
dh::safe_cuda(cudaMemcpyAsync(
prediction_cache.Data(), out_preds_d,
prediction_cache.Size() * sizeof(bst_float), cudaMemcpyDefault));
}
@@ -746,7 +754,7 @@ struct DeviceShard {
CalcWeightTrainParam param_d(param);
dh::safe_cuda(cudaMemcpy(node_sum_gradients_d.Data(),
dh::safe_cuda(cudaMemcpyAsync(node_sum_gradients_d.Data(),
node_sum_gradients.data(),
sizeof(GradientPair) * node_sum_gradients.size(),
cudaMemcpyHostToDevice));
@@ -925,9 +933,6 @@ inline void DeviceShard<GradientSumT>::CreateHistIndices(const SparsePage& row_b
batch_row_begin, batch_nrows,
row_ptrs[batch_row_begin],
row_stride, null_gidx_value);
dh::safe_cuda(cudaGetLastError());
dh::safe_cuda(cudaDeviceSynchronize());
}
// free the memory that is no longer needed
@@ -965,7 +970,7 @@ class GPUHistMakerSpecialised{
void Update(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
const std::vector<RegTree*>& trees) {
monitor_.Start("Update", dist_.Devices());
monitor_.StartCuda("Update");
// rescale learning rate according to size of trees
float lr = param_.learning_rate;
param_.learning_rate = lr / trees.size();
@@ -980,7 +985,7 @@ class GPUHistMakerSpecialised{
LOG(FATAL) << "Exception in gpu_hist: " << e.what() << std::endl;
}
param_.learning_rate = lr;
monitor_.Stop("Update", dist_.Devices());
monitor_.StopCuda("Update");
}
void InitDataOnce(DMatrix* dmat) {
@@ -1010,17 +1015,17 @@ class GPUHistMakerSpecialised{
});
// Find the cuts.
monitor_.Start("Quantiles", dist_.Devices());
monitor_.StartCuda("Quantiles");
common::DeviceSketch(batch, *info_, param_, &hmat_, hist_maker_param_.gpu_batch_nrows);
n_bins_ = hmat_.row_ptr.back();
monitor_.Stop("Quantiles", dist_.Devices());
monitor_.StopCuda("Quantiles");
monitor_.Start("BinningCompression", dist_.Devices());
monitor_.StartCuda("BinningCompression");
dh::ExecuteIndexShards(&shards_, [&](int idx,
std::unique_ptr<DeviceShard<GradientSumT>>& shard) {
shard->InitCompressedData(hmat_, batch);
});
monitor_.Stop("BinningCompression", dist_.Devices());
monitor_.StopCuda("BinningCompression");
++batch_iter;
CHECK(batch_iter.AtEnd()) << "External memory not supported";
@@ -1030,16 +1035,16 @@ class GPUHistMakerSpecialised{
void InitData(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat) {
if (!initialised_) {
monitor_.Start("InitDataOnce", dist_.Devices());
monitor_.StartCuda("InitDataOnce");
this->InitDataOnce(dmat);
monitor_.Stop("InitDataOnce", dist_.Devices());
monitor_.StopCuda("InitDataOnce");
}
column_sampler_.Init(info_->num_col_, param_.colsample_bynode,
param_.colsample_bylevel, param_.colsample_bytree);
// Copy gpair & reset memory
monitor_.Start("InitDataReset", dist_.Devices());
monitor_.StartCuda("InitDataReset");
gpair->Reshard(dist_);
dh::ExecuteIndexShards(
@@ -1047,13 +1052,12 @@ class GPUHistMakerSpecialised{
[&](int idx, std::unique_ptr<DeviceShard<GradientSumT>>& shard) {
shard->Reset(gpair);
});
monitor_.Stop("InitDataReset", dist_.Devices());
monitor_.StopCuda("InitDataReset");
}
void AllReduceHist(int nidx) {
if (shards_.size() == 1 && !rabit::IsDistributed())
return;
monitor_.Start("AllReduce");
if (shards_.size() == 1 && !rabit::IsDistributed()) return;
monitor_.StartCuda("AllReduce");
reducer_.GroupStart();
for (auto& shard : shards_) {
@@ -1067,7 +1071,7 @@ class GPUHistMakerSpecialised{
reducer_.GroupEnd();
reducer_.Synchronize();
monitor_.Stop("AllReduce");
monitor_.StopCuda("AllReduce");
}
/**
@@ -1250,12 +1254,12 @@ class GPUHistMakerSpecialised{
RegTree* p_tree) {
auto& tree = *p_tree;
monitor_.Start("InitData", dist_.Devices());
monitor_.StartCuda("InitData");
this->InitData(gpair, p_fmat);
monitor_.Stop("InitData", dist_.Devices());
monitor_.Start("InitRoot", dist_.Devices());
monitor_.StopCuda("InitData");
monitor_.StartCuda("InitRoot");
this->InitRoot(p_tree);
monitor_.Stop("InitRoot", dist_.Devices());
monitor_.StopCuda("InitRoot");
auto timestamp = qexpand_->size();
auto num_leaves = 1;
@@ -1266,9 +1270,9 @@ class GPUHistMakerSpecialised{
if (!candidate.IsValid(param_, num_leaves)) continue;
this->ApplySplit(candidate, p_tree);
monitor_.Start("UpdatePosition", dist_.Devices());
monitor_.StartCuda("UpdatePosition");
this->UpdatePosition(candidate, p_tree);
monitor_.Stop("UpdatePosition", dist_.Devices());
monitor_.StopCuda("UpdatePosition");
num_leaves++;
int left_child_nidx = tree[candidate.nid].LeftChild();
@@ -1277,32 +1281,30 @@ class GPUHistMakerSpecialised{
// Only create child entries if needed
if (ExpandEntry::ChildIsValid(param_, tree.GetDepth(left_child_nidx),
num_leaves)) {
monitor_.Start("BuildHist", dist_.Devices());
monitor_.StartCuda("BuildHist");
this->BuildHistLeftRight(candidate.nid, left_child_nidx,
right_child_nidx);
monitor_.Stop("BuildHist", dist_.Devices());
monitor_.StopCuda("BuildHist");
monitor_.Start("EvaluateSplits", dist_.Devices());
auto left_child_split =
this->EvaluateSplit(left_child_nidx, p_tree);
auto right_child_split =
this->EvaluateSplit(right_child_nidx, p_tree);
monitor_.StartCuda("EvaluateSplits");
auto left_child_split = this->EvaluateSplit(left_child_nidx, p_tree);
auto right_child_split = this->EvaluateSplit(right_child_nidx, p_tree);
qexpand_->push(ExpandEntry(left_child_nidx,
tree.GetDepth(left_child_nidx), left_child_split,
timestamp++));
tree.GetDepth(left_child_nidx),
left_child_split, timestamp++));
qexpand_->push(ExpandEntry(right_child_nidx,
tree.GetDepth(right_child_nidx), right_child_split,
timestamp++));
monitor_.Stop("EvaluateSplits", dist_.Devices());
tree.GetDepth(right_child_nidx),
right_child_split, timestamp++));
monitor_.StopCuda("EvaluateSplits");
}
}
}
bool UpdatePredictionCache(
const DMatrix* data, HostDeviceVector<bst_float>* p_out_preds) {
monitor_.Start("UpdatePredictionCache", dist_.Devices());
if (shards_.empty() || p_last_fmat_ == nullptr || p_last_fmat_ != data)
return false;
monitor_.StartCuda("UpdatePredictionCache");
p_out_preds->Reshard(dist_.Devices());
dh::ExecuteIndexShards(
&shards_,
@@ -1310,7 +1312,7 @@ class GPUHistMakerSpecialised{
shard->UpdatePredictionCache(
p_out_preds->DevicePointer(shard->device_id_));
});
monitor_.Stop("UpdatePredictionCache", dist_.Devices());
monitor_.StopCuda("UpdatePredictionCache");
return true;
}