Remove various synchronisations from cuda API calls, instrument monitor (#4205)

* Remove various synchronisations from cuda API calls, instrument monitor
with nvtx profiler ranges.
Rory Mitchell authored 2019-03-10 15:01:23 +13:00, committed by GitHub
parent f83e62dca5
commit 4eeeded7d1
9 changed files with 116 additions and 104 deletions


@@ -10,9 +10,12 @@ msvc_use_static_runtime()
 # Options
 ## GPUs
 option(USE_CUDA "Build with GPU acceleration" OFF)
+option(USE_NVTX "Build with cuda profiling annotations. Developers only." OFF)
 option(USE_NCCL "Build with multiple GPUs support" OFF)
 set(GPU_COMPUTE_VER "" CACHE STRING
   "Space separated list of compute versions to be built against, e.g. '35 61'")
+set(NVTX_HEADER_DIR "" CACHE PATH
+  "Path to the stand-alone nvtx header")
 ## Bindings
 option(JVM_BINDINGS "Build JVM bindings" OFF)
@@ -175,6 +178,11 @@ if(USE_CUDA AND (NOT GENERATE_COMPILATION_DATABASE))
     add_definitions(-DXGBOOST_USE_NCCL)
   endif()
+  if(USE_NVTX)
+    cuda_include_directories("${NVTX_HEADER_DIR}")
+    add_definitions(-DXGBOOST_USE_NVTX)
+  endif()
   set(GENCODE_FLAGS "")
   format_gencode_flags("${GPU_COMPUTE_VER}" GENCODE_FLAGS)
   message("cuda architecture flags: ${GENCODE_FLAGS}")
@@ -190,6 +198,7 @@ if(USE_CUDA AND (NOT GENERATE_COMPILATION_DATABASE))
     link_directories(${NCCL_LIBRARY})
     target_link_libraries(gpuxgboost ${NCCL_LIB_NAME})
   endif()
   list(APPEND LINK_LIBRARIES gpuxgboost)
 elseif (USE_CUDA AND GENERATE_COMPILATION_DATABASE)


@@ -195,6 +195,10 @@ Training time on 1,000,000 rows x 50 columns with 500 boosting iterations a
 See `GPU Accelerated XGBoost <https://xgboost.ai/2016/12/14/GPU-accelerated-xgboost.html>`_ and `Updates to the XGBoost GPU algorithms <https://xgboost.ai/2018/07/04/gpu-xgboost-update.html>`_ for additional performance benchmarks of the ``gpu_exact`` and ``gpu_hist`` tree methods.
 
+Developer notes
+===============
+The application may be profiled with annotations by specifying USE_NVTX to CMake and providing the path to the stand-alone NVTX header via NVTX_HEADER_DIR. Regions covered by the 'Monitor' class in CUDA code will automatically appear in the Nsight profiler.
+
 **********
 References
 **********
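
For reference, a configure step along these lines should turn the annotations on; the header path below is only an example of where nvToolsExt.h commonly lives, not a required location:

cmake .. -DUSE_CUDA=ON -DUSE_NVTX=ON -DNVTX_HEADER_DIR=/usr/local/cuda/include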


@@ -308,7 +308,7 @@ class DVec {
     }
     safe_cuda(cudaSetDevice(this->DeviceIdx()));
     if (other.DeviceIdx() == this->DeviceIdx()) {
-      dh::safe_cuda(cudaMemcpy(this->Data(), other.Data(),
+      dh::safe_cuda(cudaMemcpyAsync(this->Data(), other.Data(),
                                other.Size() * sizeof(T),
                                cudaMemcpyDeviceToDevice));
     } else {
@@ -338,7 +338,7 @@ class DVec {
       throw std::runtime_error(
           "Cannot copy assign vector to dvec, sizes are different");
     }
-    safe_cuda(cudaMemcpy(this->Data(), begin.get(), Size() * sizeof(T),
+    safe_cuda(cudaMemcpyAsync(this->Data(), begin.get(), Size() * sizeof(T),
                          cudaMemcpyDefault));
   }
 };
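
These copies now return to the host immediately; correctness relies on work issued to the same (legacy default) stream executing in order, so later kernels and copies on that stream still see the finished transfer. A minimal, self-contained sketch of that reasoning in plain CUDA (not the XGBoost dh:: wrappers):

#include <cstdio>
#include <cuda_runtime.h>

// Minimal sketch: a device-to-device cudaMemcpyAsync on the default stream is
// ordered with respect to later operations on that stream, so no explicit
// synchronisation is needed until the host has to observe the result.
int main() {
  const size_t n = 1 << 20;
  float *a = nullptr, *b = nullptr;
  cudaMalloc(&a, n * sizeof(float));
  cudaMalloc(&b, n * sizeof(float));
  cudaMemset(a, 0, n * sizeof(float));

  // Returns to the host immediately; executes after the memset above.
  cudaMemcpyAsync(b, a, n * sizeof(float), cudaMemcpyDeviceToDevice);

  // A blocking copy back to the host implicitly waits for everything queued
  // before it on the same stream.
  float first = -1.0f;
  cudaMemcpy(&first, b, sizeof(float), cudaMemcpyDeviceToHost);
  std::printf("b[0] = %f\n", first);

  cudaFree(a);
  cudaFree(b);
  return 0;
}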


@@ -290,14 +290,14 @@ struct GPUSketcher {
         offset_vec[row_begin_ + batch_row_begin];
     // copy the batch to the GPU
     dh::safe_cuda
-      (cudaMemcpy(entries_.data().get(),
+      (cudaMemcpyAsync(entries_.data().get(),
                   data_vec.data() + offset_vec[row_begin_ + batch_row_begin],
                   n_entries * sizeof(Entry), cudaMemcpyDefault));
     // copy the weights if necessary
     if (has_weights_) {
       const auto& weights_vec = info.weights_.HostVector();
       dh::safe_cuda
-        (cudaMemcpy(weights_.data().get(),
+        (cudaMemcpyAsync(weights_.data().get(),
                     weights_vec.data() + row_begin_ + batch_row_begin,
                     batch_nrows * sizeof(bst_float), cudaMemcpyDefault));
     }
@@ -315,15 +315,11 @@ struct GPUSketcher {
        has_weights_ ? weights_.data().get() : nullptr, entries_.data().get(),
        gpu_batch_nrows_, num_cols_,
        offset_vec[row_begin_ + batch_row_begin], batch_nrows);
-    dh::safe_cuda(cudaGetLastError());  // NOLINT
-    dh::safe_cuda(cudaDeviceSynchronize());  // NOLINT
     for (int icol = 0; icol < num_cols_; ++icol) {
       FindColumnCuts(batch_nrows, icol);
     }
-    dh::safe_cuda(cudaDeviceSynchronize());  // NOLINT
     // add cuts into sketches
     thrust::copy(cuts_d_.begin(), cuts_d_.end(), cuts_h_.begin());
     for (int icol = 0; icol < num_cols_; ++icol) {
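
The device-wide synchronisations removed above were not needed for correctness: the later thrust::copy from device to host cannot return before the preceding kernels on the same stream have finished. A standalone sketch of that behaviour (the cuts_d/cuts_h names only mirror the members used in the diff):

#include <thrust/copy.h>
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include <thrust/sequence.h>

// Sketch: the device-to-host thrust::copy at the end blocks the host until the
// preceding device work on the same stream has finished, so an explicit
// cudaDeviceSynchronize() before it is redundant.
int main() {
  thrust::device_vector<float> cuts_d(1024);
  thrust::sequence(cuts_d.begin(), cuts_d.end());  // stand-in for the kernels

  thrust::host_vector<float> cuts_h(cuts_d.size());
  thrust::copy(cuts_d.begin(), cuts_d.end(), cuts_h.begin());  // blocks host

  return cuts_h[10] == 10.0f ? 0 : 1;
}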


@@ -74,14 +74,14 @@ struct HostDeviceVectorImpl {
       // TODO(canonizer): avoid full copy of host data
       LazySyncDevice(GPUAccess::kWrite);
       SetDevice();
-      dh::safe_cuda(cudaMemcpy(data_.data().get(), begin + start_,
+      dh::safe_cuda(cudaMemcpyAsync(data_.data().get(), begin + start_,
                                data_.size() * sizeof(T), cudaMemcpyDefault));
     }
     void GatherTo(thrust::device_ptr<T> begin) {
       LazySyncDevice(GPUAccess::kRead);
       SetDevice();
-      dh::safe_cuda(cudaMemcpy(begin.get() + start_, data_.data().get(),
+      dh::safe_cuda(cudaMemcpyAsync(begin.get() + start_, data_.data().get(),
                                proper_size_ * sizeof(T), cudaMemcpyDefault));
     }
@@ -97,7 +97,7 @@ struct HostDeviceVectorImpl {
       LazySyncDevice(GPUAccess::kWrite);
       other->LazySyncDevice(GPUAccess::kRead);
       SetDevice();
-      dh::safe_cuda(cudaMemcpy(data_.data().get(), other->data_.data().get(),
+      dh::safe_cuda(cudaMemcpyAsync(data_.data().get(), other->data_.data().get(),
                                data_.size() * sizeof(T), cudaMemcpyDefault));
     }


@@ -8,7 +8,9 @@
 #include <map>
 #include <string>
-#include "common.h"
+
+#if defined(XGBOOST_USE_NVTX) && defined(__CUDACC__)
+#include <nvToolsExt.h>
+#endif
 
 namespace xgboost {
 namespace common {
@@ -45,9 +47,11 @@ struct Timer {
  */
 struct Monitor {
+ private:
   struct Statistics {
     Timer timer;
     size_t count{0};
+    uint64_t nvtx_id;
   };
   std::string label = "";
   std::map<std::string, Statistics> statistics_map;
@@ -75,35 +79,37 @@ struct Monitor {
     }
     self_timer.Stop();
   }
-  void Init(std::string label) {
-    this->label = label;
-  }
-  void Start(const std::string &name) { statistics_map[name].timer.Start(); }
-  void Start(const std::string &name, GPUSet devices) {
+  void Init(std::string label) { this->label = label; }
+  void Start(const std::string &name) {
     if (ConsoleLogger::ShouldLog(ConsoleLogger::LV::kDebug)) {
-#ifdef __CUDACC__
-      for (auto device : devices) {
-        cudaSetDevice(device);
-        cudaDeviceSynchronize();
-      }
-#endif  // __CUDACC__
+      statistics_map[name].timer.Start();
     }
-    statistics_map[name].timer.Start();
   }
   void Stop(const std::string &name) {
-    statistics_map[name].timer.Stop();
-    statistics_map[name].count++;
-  }
-  void Stop(const std::string &name, GPUSet devices) {
     if (ConsoleLogger::ShouldLog(ConsoleLogger::LV::kDebug)) {
-#ifdef __CUDACC__
-      for (auto device : devices) {
-        cudaSetDevice(device);
-        cudaDeviceSynchronize();
-      }
-#endif  // __CUDACC__
+      auto &stats = statistics_map[name];
+      stats.timer.Stop();
+      stats.count++;
+    }
+  }
+  void StartCuda(const std::string &name) {
+    if (ConsoleLogger::ShouldLog(ConsoleLogger::LV::kDebug)) {
+      auto &stats = statistics_map[name];
+      stats.timer.Start();
+#if defined(XGBOOST_USE_NVTX) && defined(__CUDACC__)
+      stats.nvtx_id = nvtxRangeStartA(name.c_str());
+#endif
+    }
+  }
+  void StopCuda(const std::string &name) {
+    if (ConsoleLogger::ShouldLog(ConsoleLogger::LV::kDebug)) {
+      auto &stats = statistics_map[name];
+      stats.timer.Stop();
+      stats.count++;
+#if defined(XGBOOST_USE_NVTX) && defined(__CUDACC__)
+      nvtxRangeEnd(stats.nvtx_id);
+#endif
     }
-    this->Stop(name);
   }
 };
 } // namespace common
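
A hypothetical caller illustrating the intended use of the new StartCuda/StopCuda pair (the surrounding function and region name are illustrative, not taken from the code base; it assumes the Monitor header above is included):

// Illustrative only: times a GPU region with the Monitor shown above.  When
// built with -DXGBOOST_USE_NVTX under nvcc, the same calls also open and close
// an NVTX range that appears in the Nsight timeline.  Timing and ranges are
// recorded only when debug-level logging is enabled.
void BuildHistogramsOnGpu() {
  static xgboost::common::Monitor monitor;
  monitor.Init("gpu_hist");

  monitor.StartCuda("BuildHist");
  // ... launch histogram kernels on the default stream ...
  monitor.StopCuda("BuildHist");
}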


@@ -145,8 +145,6 @@ class Transform {
           static_cast<int>(dh::DivRoundUp(*(range_.end()), kBlockThreads));
       detail::LaunchCUDAKernel<<<GRID_SIZE, kBlockThreads>>>(
           _func, shard_range, UnpackHDV(_vectors, device)...);
-      dh::safe_cuda(cudaGetLastError());
-      dh::safe_cuda(cudaDeviceSynchronize());
     }
   }
 #else
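
Dropping the post-launch check removes a device-wide stall. If launch-error checking is still wanted without the stall, cudaGetLastError() on its own reports launch-configuration failures; this is a standalone sketch, not something this commit adds:

#include <cstdio>
#include <cuda_runtime.h>

__global__ void Noop() {}

int main() {
  // 2048 threads per block exceeds the per-block limit on current GPUs, so the
  // launch fails; cudaGetLastError() reports it without synchronising the device.
  Noop<<<1, 2048>>>();
  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess) {
    std::printf("launch failed: %s\n", cudaGetErrorString(err));
  }
  return 0;
}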


@@ -252,17 +252,17 @@ class GPUPredictor : public xgboost::Predictor {
                  size_t tree_begin, size_t tree_end) {
     dh::safe_cuda(cudaSetDevice(device_));
     nodes.resize(h_nodes.size());
-    dh::safe_cuda(cudaMemcpy(dh::Raw(nodes), h_nodes.data(),
+    dh::safe_cuda(cudaMemcpyAsync(dh::Raw(nodes), h_nodes.data(),
                              sizeof(DevicePredictionNode) * h_nodes.size(),
                              cudaMemcpyHostToDevice));
     tree_segments.resize(h_tree_segments.size());
-    dh::safe_cuda(cudaMemcpy(dh::Raw(tree_segments), h_tree_segments.data(),
+    dh::safe_cuda(cudaMemcpyAsync(dh::Raw(tree_segments), h_tree_segments.data(),
                              sizeof(size_t) * h_tree_segments.size(),
                              cudaMemcpyHostToDevice));
     tree_group.resize(model.tree_info.size());
-    dh::safe_cuda(cudaMemcpy(dh::Raw(tree_group), model.tree_info.data(),
+    dh::safe_cuda(cudaMemcpyAsync(dh::Raw(tree_group), model.tree_info.data(),
                              sizeof(int) * model.tree_info.size(),
                              cudaMemcpyHostToDevice));
@@ -288,9 +288,6 @@ class GPUPredictor : public xgboost::Predictor {
         dh::ToSpan(tree_group), batch.offset.DeviceSpan(device_),
         batch.data.DeviceSpan(device_), tree_begin, tree_end, info.num_col_,
         num_rows, entry_start, use_shared, model.param.num_output_group);
-    dh::safe_cuda(cudaGetLastError());
-    dh::safe_cuda(cudaDeviceSynchronize());
   }
 
   int device_;


@@ -289,7 +289,9 @@ struct DeviceHistogram {
   void Reset() {
     dh::safe_cuda(cudaSetDevice(device_id_));
-    data.resize(0);
+    dh::safe_cuda(cudaMemsetAsync(
+        data.data().get(), 0,
+        data.size() * sizeof(typename decltype(data)::value_type)));
     nidx_map.clear();
   }
@@ -299,20 +301,25 @@
   void AllocateHistogram(int nidx) {
     if (HistogramExists(nidx)) return;
+    size_t current_size =
+        nidx_map.size() * n_bins * 2;  // Number of items currently used in data
     dh::safe_cuda(cudaSetDevice(device_id_));
-    if (data.size() > kStopGrowingSize) {
+    if (data.size() >= kStopGrowingSize) {
       // Recycle histogram memory
       std::pair<int, size_t> old_entry = *nidx_map.begin();
       nidx_map.erase(old_entry.first);
-      dh::safe_cuda(cudaMemset(data.data().get() + old_entry.second, 0,
+      dh::safe_cuda(cudaMemsetAsync(data.data().get() + old_entry.second, 0,
                                n_bins * sizeof(GradientSumT)));
       nidx_map[nidx] = old_entry.second;
     } else {
       // Append new node histogram
-      nidx_map[nidx] = data.size();
-      // x 2: Hess and Grad.
-      data.resize(data.size() + (n_bins * 2));
+      nidx_map[nidx] = current_size;
+      if (data.size() < current_size + n_bins * 2) {
+        size_t new_size = current_size * 2;  // Double in size
+        new_size = std::max(static_cast<size_t>(n_bins * 2),
+                            new_size);  // Have at least one histogram
+        data.resize(new_size);
+      }
     }
   }
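
The new policy grows the backing buffer geometrically rather than resizing it by exactly one histogram per node, and recycles the oldest slot once the buffer reaches kStopGrowingSize; since the buffer is never shrunk, Reset() above can simply zero it asynchronously instead of reallocating. A host-side sketch of the same policy (the names mirror the diff, but this is not the XGBoost class):

#include <algorithm>
#include <cstddef>
#include <map>
#include <vector>

// Host-side sketch of the allocation policy above.  Below kStopGrowingSize the
// backing buffer grows by doubling instead of being resized per node; above
// it, the oldest node's slot is recycled.
struct HistogramPoolSketch {
  std::size_t n_bins = 256;
  std::size_t kStopGrowingSize = 1 << 20;
  std::map<int, std::size_t> nidx_map;  // node id -> offset into data
  std::vector<double> data;             // stand-in for the device buffer

  void AllocateHistogram(int nidx) {
    if (nidx_map.count(nidx)) return;
    std::size_t current_size = nidx_map.size() * n_bins * 2;  // items in use
    if (data.size() >= kStopGrowingSize) {
      auto old_entry = *nidx_map.begin();  // recycle the oldest slot
      nidx_map.erase(old_entry.first);
      std::fill_n(data.begin() + old_entry.second, n_bins * 2, 0.0);
      nidx_map[nidx] = old_entry.second;
    } else {
      nidx_map[nidx] = current_size;
      if (data.size() < current_size + n_bins * 2) {
        std::size_t new_size = std::max(current_size * 2,  // double in size
                                        n_bins * 2);       // at least one histogram
        data.resize(new_size);
      }
    }
  }
};
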
@@ -610,20 +617,20 @@ struct DeviceShard {
     feature_set_d.resize(feature_set.size());
     auto d_features = common::Span<int>(feature_set_d.data().get(),
                                         feature_set_d.size());
-    dh::safe_cuda(cudaMemcpy(d_features.data(), feature_set.data(),
+    dh::safe_cuda(cudaMemcpyAsync(d_features.data(), feature_set.data(),
                              d_features.size_bytes(), cudaMemcpyDefault));
     DeviceNodeStats node(node_sum_gradients[nidx], nidx, param);
     // One block for each feature
     int constexpr BLOCK_THREADS = 256;
     EvaluateSplitKernel<BLOCK_THREADS, GradientSumT>
-        <<<uint32_t(feature_set.size()), BLOCK_THREADS, 0>>>
-        (hist.GetNodeHistogram(nidx), d_features, node,
+        <<<uint32_t(feature_set.size()), BLOCK_THREADS, 0>>>(
+            hist.GetNodeHistogram(nidx), d_features, node,
             cut_.feature_segments.GetSpan(), cut_.min_fvalue.GetSpan(),
             cut_.gidx_fvalue_map.GetSpan(), GPUTrainingParam(param),
-            d_split_candidates, value_constraint, monotone_constraints.GetSpan());
-    dh::safe_cuda(cudaDeviceSynchronize());
+            d_split_candidates, value_constraint,
+            monotone_constraints.GetSpan());
     std::vector<DeviceSplitCandidate> split_candidates(feature_set.size());
     dh::safe_cuda(cudaMemcpy(split_candidates.data(), d_split_candidates.data(),
                              split_candidates.size() * sizeof(DeviceSplitCandidate),
@@ -725,20 +732,21 @@ struct DeviceShard {
         common::Span<bst_uint>(ridx.Current() + segment.begin, segment.Size()),
         common::Span<bst_uint>(ridx.other() + segment.begin, segment.Size()),
         left_nidx, right_nidx, left_count);
-    // Copy back key
-    dh::safe_cuda(cudaMemcpy(
-        position.Current() + segment.begin, position.other() + segment.begin,
-        segment.Size() * sizeof(int), cudaMemcpyDeviceToDevice));
-    // Copy back value
-    dh::safe_cuda(cudaMemcpy(
-        ridx.Current() + segment.begin, ridx.other() + segment.begin,
-        segment.Size() * sizeof(bst_uint), cudaMemcpyDeviceToDevice));
+    // Copy back key/value
+    const auto d_position_current = position.Current() + segment.begin;
+    const auto d_position_other = position.other() + segment.begin;
+    const auto d_ridx_current = ridx.Current() + segment.begin;
+    const auto d_ridx_other = ridx.other() + segment.begin;
+    dh::LaunchN(device_id_, segment.Size(), [=] __device__(size_t idx) {
+      d_position_current[idx] = d_position_other[idx];
+      d_ridx_current[idx] = d_ridx_other[idx];
+    });
   }
 
   void UpdatePredictionCache(bst_float* out_preds_d) {
     dh::safe_cuda(cudaSetDevice(device_id_));
     if (!prediction_cache_initialised) {
-      dh::safe_cuda(cudaMemcpy(
+      dh::safe_cuda(cudaMemcpyAsync(
           prediction_cache.Data(), out_preds_d,
           prediction_cache.Size() * sizeof(bst_float), cudaMemcpyDefault));
     }
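
In UpdatePosition, the two back-to-back device-to-device copies are replaced with a single fused element-wise kernel launched through dh::LaunchN. A standalone sketch of the same fused copy, written as a plain CUDA kernel rather than the dh::LaunchN helper (all names here are illustrative):

#include <cuda_runtime.h>

// One launch copies both the position and row-index buffers for a segment,
// instead of issuing two separate cudaMemcpy calls.
__global__ void CopyBackKernel(int* position_dst, const int* position_src,
                               unsigned int* ridx_dst, const unsigned int* ridx_src,
                               size_t n) {
  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < n) {
    position_dst[idx] = position_src[idx];
    ridx_dst[idx] = ridx_src[idx];
  }
}

// Hypothetical wrapper showing the launch; all pointers are device pointers
// already offset to the start of the segment, mirroring the code above.
inline void CopyBack(int* position_dst, const int* position_src,
                     unsigned int* ridx_dst, const unsigned int* ridx_src,
                     size_t n) {
  const int kBlockThreads = 256;
  const int grid = static_cast<int>((n + kBlockThreads - 1) / kBlockThreads);
  CopyBackKernel<<<grid, kBlockThreads>>>(position_dst, position_src,
                                          ridx_dst, ridx_src, n);
}
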
@@ -746,7 +754,7 @@ struct DeviceShard {
     CalcWeightTrainParam param_d(param);
-    dh::safe_cuda(cudaMemcpy(node_sum_gradients_d.Data(),
+    dh::safe_cuda(cudaMemcpyAsync(node_sum_gradients_d.Data(),
                              node_sum_gradients.data(),
                              sizeof(GradientPair) * node_sum_gradients.size(),
                              cudaMemcpyHostToDevice));
@@ -925,9 +933,6 @@ inline void DeviceShard<GradientSumT>::CreateHistIndices(const SparsePage& row_b
           batch_row_begin, batch_nrows,
           row_ptrs[batch_row_begin],
           row_stride, null_gidx_value);
-      dh::safe_cuda(cudaGetLastError());
-      dh::safe_cuda(cudaDeviceSynchronize());
     }
 
   // free the memory that is no longer needed
@@ -965,7 +970,7 @@ class GPUHistMakerSpecialised{
   void Update(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
               const std::vector<RegTree*>& trees) {
-    monitor_.Start("Update", dist_.Devices());
+    monitor_.StartCuda("Update");
     // rescale learning rate according to size of trees
     float lr = param_.learning_rate;
     param_.learning_rate = lr / trees.size();
@@ -980,7 +985,7 @@ class GPUHistMakerSpecialised{
       LOG(FATAL) << "Exception in gpu_hist: " << e.what() << std::endl;
     }
     param_.learning_rate = lr;
-    monitor_.Stop("Update", dist_.Devices());
+    monitor_.StopCuda("Update");
   }
 
   void InitDataOnce(DMatrix* dmat) {
@@ -1010,17 +1015,17 @@ class GPUHistMakerSpecialised{
     });
     // Find the cuts.
-    monitor_.Start("Quantiles", dist_.Devices());
+    monitor_.StartCuda("Quantiles");
     common::DeviceSketch(batch, *info_, param_, &hmat_, hist_maker_param_.gpu_batch_nrows);
     n_bins_ = hmat_.row_ptr.back();
-    monitor_.Stop("Quantiles", dist_.Devices());
+    monitor_.StopCuda("Quantiles");
-    monitor_.Start("BinningCompression", dist_.Devices());
+    monitor_.StartCuda("BinningCompression");
     dh::ExecuteIndexShards(&shards_, [&](int idx,
         std::unique_ptr<DeviceShard<GradientSumT>>& shard) {
       shard->InitCompressedData(hmat_, batch);
     });
-    monitor_.Stop("BinningCompression", dist_.Devices());
+    monitor_.StopCuda("BinningCompression");
     ++batch_iter;
     CHECK(batch_iter.AtEnd()) << "External memory not supported";
@@ -1030,16 +1035,16 @@ class GPUHistMakerSpecialised{
   void InitData(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat) {
     if (!initialised_) {
-      monitor_.Start("InitDataOnce", dist_.Devices());
+      monitor_.StartCuda("InitDataOnce");
       this->InitDataOnce(dmat);
-      monitor_.Stop("InitDataOnce", dist_.Devices());
+      monitor_.StopCuda("InitDataOnce");
     }
     column_sampler_.Init(info_->num_col_, param_.colsample_bynode,
                          param_.colsample_bylevel, param_.colsample_bytree);
     // Copy gpair & reset memory
-    monitor_.Start("InitDataReset", dist_.Devices());
+    monitor_.StartCuda("InitDataReset");
     gpair->Reshard(dist_);
     dh::ExecuteIndexShards(
@@ -1047,13 +1052,12 @@ class GPUHistMakerSpecialised{
         [&](int idx, std::unique_ptr<DeviceShard<GradientSumT>>& shard) {
           shard->Reset(gpair);
         });
-    monitor_.Stop("InitDataReset", dist_.Devices());
+    monitor_.StopCuda("InitDataReset");
   }
 
   void AllReduceHist(int nidx) {
-    if (shards_.size() == 1 && !rabit::IsDistributed())
-      return;
-    monitor_.Start("AllReduce");
+    if (shards_.size() == 1 && !rabit::IsDistributed()) return;
+    monitor_.StartCuda("AllReduce");
     reducer_.GroupStart();
     for (auto& shard : shards_) {
@@ -1067,7 +1071,7 @@ class GPUHistMakerSpecialised{
     reducer_.GroupEnd();
     reducer_.Synchronize();
-    monitor_.Stop("AllReduce");
+    monitor_.StopCuda("AllReduce");
   }
 
   /**
@@ -1250,12 +1254,12 @@ class GPUHistMakerSpecialised{
                  RegTree* p_tree) {
     auto& tree = *p_tree;
-    monitor_.Start("InitData", dist_.Devices());
+    monitor_.StartCuda("InitData");
     this->InitData(gpair, p_fmat);
-    monitor_.Stop("InitData", dist_.Devices());
-    monitor_.Start("InitRoot", dist_.Devices());
+    monitor_.StopCuda("InitData");
+    monitor_.StartCuda("InitRoot");
     this->InitRoot(p_tree);
-    monitor_.Stop("InitRoot", dist_.Devices());
+    monitor_.StopCuda("InitRoot");
     auto timestamp = qexpand_->size();
     auto num_leaves = 1;
@@ -1266,9 +1270,9 @@ class GPUHistMakerSpecialised{
       if (!candidate.IsValid(param_, num_leaves)) continue;
       this->ApplySplit(candidate, p_tree);
-      monitor_.Start("UpdatePosition", dist_.Devices());
+      monitor_.StartCuda("UpdatePosition");
       this->UpdatePosition(candidate, p_tree);
-      monitor_.Stop("UpdatePosition", dist_.Devices());
+      monitor_.StopCuda("UpdatePosition");
       num_leaves++;
       int left_child_nidx = tree[candidate.nid].LeftChild();
@@ -1277,32 +1281,30 @@ class GPUHistMakerSpecialised{
       // Only create child entries if needed
       if (ExpandEntry::ChildIsValid(param_, tree.GetDepth(left_child_nidx),
                                     num_leaves)) {
-        monitor_.Start("BuildHist", dist_.Devices());
+        monitor_.StartCuda("BuildHist");
         this->BuildHistLeftRight(candidate.nid, left_child_nidx,
                                  right_child_nidx);
-        monitor_.Stop("BuildHist", dist_.Devices());
+        monitor_.StopCuda("BuildHist");
-        monitor_.Start("EvaluateSplits", dist_.Devices());
-        auto left_child_split =
-            this->EvaluateSplit(left_child_nidx, p_tree);
-        auto right_child_split =
-            this->EvaluateSplit(right_child_nidx, p_tree);
+        monitor_.StartCuda("EvaluateSplits");
+        auto left_child_split = this->EvaluateSplit(left_child_nidx, p_tree);
+        auto right_child_split = this->EvaluateSplit(right_child_nidx, p_tree);
         qexpand_->push(ExpandEntry(left_child_nidx,
-                       tree.GetDepth(left_child_nidx), left_child_split,
-                       timestamp++));
+                                   tree.GetDepth(left_child_nidx),
+                                   left_child_split, timestamp++));
         qexpand_->push(ExpandEntry(right_child_nidx,
-                       tree.GetDepth(right_child_nidx), right_child_split,
-                       timestamp++));
-        monitor_.Stop("EvaluateSplits", dist_.Devices());
+                                   tree.GetDepth(right_child_nidx),
+                                   right_child_split, timestamp++));
+        monitor_.StopCuda("EvaluateSplits");
       }
     }
   }
 
   bool UpdatePredictionCache(
       const DMatrix* data, HostDeviceVector<bst_float>* p_out_preds) {
-    monitor_.Start("UpdatePredictionCache", dist_.Devices());
     if (shards_.empty() || p_last_fmat_ == nullptr || p_last_fmat_ != data)
       return false;
+    monitor_.StartCuda("UpdatePredictionCache");
     p_out_preds->Reshard(dist_.Devices());
     dh::ExecuteIndexShards(
         &shards_,
@@ -1310,7 +1312,7 @@ class GPUHistMakerSpecialised{
           shard->UpdatePredictionCache(
               p_out_preds->DevicePointer(shard->device_id_));
         });
-    monitor_.Stop("UpdatePredictionCache", dist_.Devices());
+    monitor_.StopCuda("UpdatePredictionCache");
     return true;
   }