Remove various synchronisations from CUDA API calls, instrument monitor (#4205)
* Remove various synchronisations from CUDA API calls and instrument the monitor with NVTX profiler ranges.
This commit is contained in:
@@ -289,7 +289,9 @@ struct DeviceHistogram {
|
||||
|
||||
void Reset() {
|
||||
dh::safe_cuda(cudaSetDevice(device_id_));
|
||||
data.resize(0);
|
||||
dh::safe_cuda(cudaMemsetAsync(
|
||||
data.data().get(), 0,
|
||||
data.size() * sizeof(typename decltype(data)::value_type)));
|
||||
nidx_map.clear();
|
||||
}
|
||||
|
||||
@@ -299,20 +301,25 @@ struct DeviceHistogram {
|
||||
|
||||
void AllocateHistogram(int nidx) {
|
||||
if (HistogramExists(nidx)) return;
|
||||
|
||||
size_t current_size =
|
||||
nidx_map.size() * n_bins * 2; // Number of items currently used in data
|
||||
dh::safe_cuda(cudaSetDevice(device_id_));
|
||||
if (data.size() > kStopGrowingSize) {
|
||||
if (data.size() >= kStopGrowingSize) {
|
||||
// Recycle histogram memory
|
||||
std::pair<int, size_t> old_entry = *nidx_map.begin();
|
||||
nidx_map.erase(old_entry.first);
|
||||
dh::safe_cuda(cudaMemset(data.data().get() + old_entry.second, 0,
|
||||
dh::safe_cuda(cudaMemsetAsync(data.data().get() + old_entry.second, 0,
|
||||
n_bins * sizeof(GradientSumT)));
|
||||
nidx_map[nidx] = old_entry.second;
|
||||
} else {
|
||||
// Append new node histogram
|
||||
nidx_map[nidx] = data.size();
|
||||
// x 2: Hess and Grad.
|
||||
data.resize(data.size() + (n_bins * 2));
|
||||
nidx_map[nidx] = current_size;
|
||||
if (data.size() < current_size + n_bins * 2) {
|
||||
size_t new_size = current_size * 2; // Double in size
|
||||
new_size = std::max(static_cast<size_t>(n_bins * 2),
|
||||
new_size); // Have at least one histogram
|
||||
data.resize(new_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -610,20 +617,20 @@ struct DeviceShard {
|
||||
feature_set_d.resize(feature_set.size());
|
||||
auto d_features = common::Span<int>(feature_set_d.data().get(),
|
||||
feature_set_d.size());
|
||||
dh::safe_cuda(cudaMemcpy(d_features.data(), feature_set.data(),
|
||||
dh::safe_cuda(cudaMemcpyAsync(d_features.data(), feature_set.data(),
|
||||
d_features.size_bytes(), cudaMemcpyDefault));
|
||||
DeviceNodeStats node(node_sum_gradients[nidx], nidx, param);
|
||||
|
||||
// One block for each feature
|
||||
int constexpr BLOCK_THREADS = 256;
|
||||
EvaluateSplitKernel<BLOCK_THREADS, GradientSumT>
|
||||
<<<uint32_t(feature_set.size()), BLOCK_THREADS, 0>>>
|
||||
(hist.GetNodeHistogram(nidx), d_features, node,
|
||||
cut_.feature_segments.GetSpan(), cut_.min_fvalue.GetSpan(),
|
||||
cut_.gidx_fvalue_map.GetSpan(), GPUTrainingParam(param),
|
||||
d_split_candidates, value_constraint, monotone_constraints.GetSpan());
|
||||
<<<uint32_t(feature_set.size()), BLOCK_THREADS, 0>>>(
|
||||
hist.GetNodeHistogram(nidx), d_features, node,
|
||||
cut_.feature_segments.GetSpan(), cut_.min_fvalue.GetSpan(),
|
||||
cut_.gidx_fvalue_map.GetSpan(), GPUTrainingParam(param),
|
||||
d_split_candidates, value_constraint,
|
||||
monotone_constraints.GetSpan());
|
||||
|
||||
dh::safe_cuda(cudaDeviceSynchronize());
|
||||
std::vector<DeviceSplitCandidate> split_candidates(feature_set.size());
|
||||
dh::safe_cuda(cudaMemcpy(split_candidates.data(), d_split_candidates.data(),
|
||||
split_candidates.size() * sizeof(DeviceSplitCandidate),
|
||||
@@ -725,20 +732,21 @@ struct DeviceShard {
|
||||
common::Span<bst_uint>(ridx.Current() + segment.begin, segment.Size()),
|
||||
common::Span<bst_uint>(ridx.other() + segment.begin, segment.Size()),
|
||||
left_nidx, right_nidx, left_count);
|
||||
// Copy back key
|
||||
dh::safe_cuda(cudaMemcpy(
|
||||
position.Current() + segment.begin, position.other() + segment.begin,
|
||||
segment.Size() * sizeof(int), cudaMemcpyDeviceToDevice));
|
||||
// Copy back value
|
||||
dh::safe_cuda(cudaMemcpy(
|
||||
ridx.Current() + segment.begin, ridx.other() + segment.begin,
|
||||
segment.Size() * sizeof(bst_uint), cudaMemcpyDeviceToDevice));
|
||||
// Copy back key/value
|
||||
const auto d_position_current = position.Current() + segment.begin;
|
||||
const auto d_position_other = position.other() + segment.begin;
|
||||
const auto d_ridx_current = ridx.Current() + segment.begin;
|
||||
const auto d_ridx_other = ridx.other() + segment.begin;
|
||||
dh::LaunchN(device_id_, segment.Size(), [=] __device__(size_t idx) {
|
||||
d_position_current[idx] = d_position_other[idx];
|
||||
d_ridx_current[idx] = d_ridx_other[idx];
|
||||
});
|
||||
}
|
||||
|
||||
void UpdatePredictionCache(bst_float* out_preds_d) {
|
||||
dh::safe_cuda(cudaSetDevice(device_id_));
|
||||
if (!prediction_cache_initialised) {
|
||||
dh::safe_cuda(cudaMemcpy(
|
||||
dh::safe_cuda(cudaMemcpyAsync(
|
||||
prediction_cache.Data(), out_preds_d,
|
||||
prediction_cache.Size() * sizeof(bst_float), cudaMemcpyDefault));
|
||||
}
|
||||
@@ -746,7 +754,7 @@ struct DeviceShard {
|
||||
|
||||
CalcWeightTrainParam param_d(param);
|
||||
|
||||
dh::safe_cuda(cudaMemcpy(node_sum_gradients_d.Data(),
|
||||
dh::safe_cuda(cudaMemcpyAsync(node_sum_gradients_d.Data(),
|
||||
node_sum_gradients.data(),
|
||||
sizeof(GradientPair) * node_sum_gradients.size(),
|
||||
cudaMemcpyHostToDevice));
|
||||
@@ -925,9 +933,6 @@ inline void DeviceShard<GradientSumT>::CreateHistIndices(const SparsePage& row_b
|
||||
batch_row_begin, batch_nrows,
|
||||
row_ptrs[batch_row_begin],
|
||||
row_stride, null_gidx_value);
|
||||
|
||||
dh::safe_cuda(cudaGetLastError());
|
||||
dh::safe_cuda(cudaDeviceSynchronize());
|
||||
}
|
||||
|
||||
// free the memory that is no longer needed
|
||||
@@ -965,7 +970,7 @@ class GPUHistMakerSpecialised{
|
||||
|
||||
void Update(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
|
||||
const std::vector<RegTree*>& trees) {
|
||||
monitor_.Start("Update", dist_.Devices());
|
||||
monitor_.StartCuda("Update");
|
||||
// rescale learning rate according to size of trees
|
||||
float lr = param_.learning_rate;
|
||||
param_.learning_rate = lr / trees.size();
|
||||
@@ -980,7 +985,7 @@ class GPUHistMakerSpecialised{
|
||||
LOG(FATAL) << "Exception in gpu_hist: " << e.what() << std::endl;
|
||||
}
|
||||
param_.learning_rate = lr;
|
||||
monitor_.Stop("Update", dist_.Devices());
|
||||
monitor_.StopCuda("Update");
|
||||
}
|
||||
|
||||
void InitDataOnce(DMatrix* dmat) {
|
||||
@@ -1010,17 +1015,17 @@ class GPUHistMakerSpecialised{
|
||||
});
|
||||
|
||||
// Find the cuts.
|
||||
monitor_.Start("Quantiles", dist_.Devices());
|
||||
monitor_.StartCuda("Quantiles");
|
||||
common::DeviceSketch(batch, *info_, param_, &hmat_, hist_maker_param_.gpu_batch_nrows);
|
||||
n_bins_ = hmat_.row_ptr.back();
|
||||
monitor_.Stop("Quantiles", dist_.Devices());
|
||||
monitor_.StopCuda("Quantiles");
|
||||
|
||||
monitor_.Start("BinningCompression", dist_.Devices());
|
||||
monitor_.StartCuda("BinningCompression");
|
||||
dh::ExecuteIndexShards(&shards_, [&](int idx,
|
||||
std::unique_ptr<DeviceShard<GradientSumT>>& shard) {
|
||||
shard->InitCompressedData(hmat_, batch);
|
||||
});
|
||||
monitor_.Stop("BinningCompression", dist_.Devices());
|
||||
monitor_.StopCuda("BinningCompression");
|
||||
++batch_iter;
|
||||
CHECK(batch_iter.AtEnd()) << "External memory not supported";
|
||||
|
||||
@@ -1030,16 +1035,16 @@ class GPUHistMakerSpecialised{
|
||||
|
||||
void InitData(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat) {
|
||||
if (!initialised_) {
|
||||
monitor_.Start("InitDataOnce", dist_.Devices());
|
||||
monitor_.StartCuda("InitDataOnce");
|
||||
this->InitDataOnce(dmat);
|
||||
monitor_.Stop("InitDataOnce", dist_.Devices());
|
||||
monitor_.StopCuda("InitDataOnce");
|
||||
}
|
||||
|
||||
column_sampler_.Init(info_->num_col_, param_.colsample_bynode,
|
||||
param_.colsample_bylevel, param_.colsample_bytree);
|
||||
|
||||
// Copy gpair & reset memory
|
||||
monitor_.Start("InitDataReset", dist_.Devices());
|
||||
monitor_.StartCuda("InitDataReset");
|
||||
|
||||
gpair->Reshard(dist_);
|
||||
dh::ExecuteIndexShards(
|
||||
@@ -1047,13 +1052,12 @@ class GPUHistMakerSpecialised{
|
||||
[&](int idx, std::unique_ptr<DeviceShard<GradientSumT>>& shard) {
|
||||
shard->Reset(gpair);
|
||||
});
|
||||
monitor_.Stop("InitDataReset", dist_.Devices());
|
||||
monitor_.StopCuda("InitDataReset");
|
||||
}
|
||||
|
||||
void AllReduceHist(int nidx) {
|
||||
if (shards_.size() == 1 && !rabit::IsDistributed())
|
||||
return;
|
||||
monitor_.Start("AllReduce");
|
||||
if (shards_.size() == 1 && !rabit::IsDistributed()) return;
|
||||
monitor_.StartCuda("AllReduce");
|
||||
|
||||
reducer_.GroupStart();
|
||||
for (auto& shard : shards_) {
|
||||
@@ -1067,7 +1071,7 @@ class GPUHistMakerSpecialised{
|
||||
reducer_.GroupEnd();
|
||||
reducer_.Synchronize();
|
||||
|
||||
monitor_.Stop("AllReduce");
|
||||
monitor_.StopCuda("AllReduce");
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -1250,12 +1254,12 @@ class GPUHistMakerSpecialised{
|
||||
RegTree* p_tree) {
|
||||
auto& tree = *p_tree;
|
||||
|
||||
monitor_.Start("InitData", dist_.Devices());
|
||||
monitor_.StartCuda("InitData");
|
||||
this->InitData(gpair, p_fmat);
|
||||
monitor_.Stop("InitData", dist_.Devices());
|
||||
monitor_.Start("InitRoot", dist_.Devices());
|
||||
monitor_.StopCuda("InitData");
|
||||
monitor_.StartCuda("InitRoot");
|
||||
this->InitRoot(p_tree);
|
||||
monitor_.Stop("InitRoot", dist_.Devices());
|
||||
monitor_.StopCuda("InitRoot");
|
||||
|
||||
auto timestamp = qexpand_->size();
|
||||
auto num_leaves = 1;
|
||||
@@ -1266,9 +1270,9 @@ class GPUHistMakerSpecialised{
|
||||
if (!candidate.IsValid(param_, num_leaves)) continue;
|
||||
|
||||
this->ApplySplit(candidate, p_tree);
|
||||
monitor_.Start("UpdatePosition", dist_.Devices());
|
||||
monitor_.StartCuda("UpdatePosition");
|
||||
this->UpdatePosition(candidate, p_tree);
|
||||
monitor_.Stop("UpdatePosition", dist_.Devices());
|
||||
monitor_.StopCuda("UpdatePosition");
|
||||
num_leaves++;
|
||||
|
||||
int left_child_nidx = tree[candidate.nid].LeftChild();
|
||||
@@ -1277,32 +1281,30 @@ class GPUHistMakerSpecialised{
|
||||
// Only create child entries if needed
|
||||
if (ExpandEntry::ChildIsValid(param_, tree.GetDepth(left_child_nidx),
|
||||
num_leaves)) {
|
||||
monitor_.Start("BuildHist", dist_.Devices());
|
||||
monitor_.StartCuda("BuildHist");
|
||||
this->BuildHistLeftRight(candidate.nid, left_child_nidx,
|
||||
right_child_nidx);
|
||||
monitor_.Stop("BuildHist", dist_.Devices());
|
||||
monitor_.StopCuda("BuildHist");
|
||||
|
||||
monitor_.Start("EvaluateSplits", dist_.Devices());
|
||||
auto left_child_split =
|
||||
this->EvaluateSplit(left_child_nidx, p_tree);
|
||||
auto right_child_split =
|
||||
this->EvaluateSplit(right_child_nidx, p_tree);
|
||||
monitor_.StartCuda("EvaluateSplits");
|
||||
auto left_child_split = this->EvaluateSplit(left_child_nidx, p_tree);
|
||||
auto right_child_split = this->EvaluateSplit(right_child_nidx, p_tree);
|
||||
qexpand_->push(ExpandEntry(left_child_nidx,
|
||||
tree.GetDepth(left_child_nidx), left_child_split,
|
||||
timestamp++));
|
||||
tree.GetDepth(left_child_nidx),
|
||||
left_child_split, timestamp++));
|
||||
qexpand_->push(ExpandEntry(right_child_nidx,
|
||||
tree.GetDepth(right_child_nidx), right_child_split,
|
||||
timestamp++));
|
||||
monitor_.Stop("EvaluateSplits", dist_.Devices());
|
||||
tree.GetDepth(right_child_nidx),
|
||||
right_child_split, timestamp++));
|
||||
monitor_.StopCuda("EvaluateSplits");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool UpdatePredictionCache(
|
||||
const DMatrix* data, HostDeviceVector<bst_float>* p_out_preds) {
|
||||
monitor_.Start("UpdatePredictionCache", dist_.Devices());
|
||||
if (shards_.empty() || p_last_fmat_ == nullptr || p_last_fmat_ != data)
|
||||
return false;
|
||||
monitor_.StartCuda("UpdatePredictionCache");
|
||||
p_out_preds->Reshard(dist_.Devices());
|
||||
dh::ExecuteIndexShards(
|
||||
&shards_,
|
||||
@@ -1310,7 +1312,7 @@ class GPUHistMakerSpecialised{
|
||||
shard->UpdatePredictionCache(
|
||||
p_out_preds->DevicePointer(shard->device_id_));
|
||||
});
|
||||
monitor_.Stop("UpdatePredictionCache", dist_.Devices());
|
||||
monitor_.StopCuda("UpdatePredictionCache");
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user