diff --git a/CMakeLists.txt b/CMakeLists.txt
index 4a7ebdf25..19ba17612 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -10,9 +10,12 @@ msvc_use_static_runtime()
 # Options
 ## GPUs
 option(USE_CUDA "Build with GPU acceleration" OFF)
+option(USE_NVTX "Build with CUDA profiling annotations. Developers only." OFF)
 option(USE_NCCL "Build with multiple GPUs support" OFF)
 set(GPU_COMPUTE_VER "" CACHE STRING
   "Space separated list of compute versions to be built against, e.g. '35 61'")
+set(NVTX_HEADER_DIR "" CACHE PATH
+  "Path to the stand-alone NVTX header")
 
 ## Bindings
 option(JVM_BINDINGS "Build JVM bindings" OFF)
@@ -175,6 +178,11 @@ if(USE_CUDA AND (NOT GENERATE_COMPILATION_DATABASE))
     add_definitions(-DXGBOOST_USE_NCCL)
   endif()
 
+  if(USE_NVTX)
+    cuda_include_directories("${NVTX_HEADER_DIR}")
+    add_definitions(-DXGBOOST_USE_NVTX)
+  endif()
+
   set(GENCODE_FLAGS "")
   format_gencode_flags("${GPU_COMPUTE_VER}" GENCODE_FLAGS)
   message("cuda architecture flags: ${GENCODE_FLAGS}")
@@ -190,6 +198,7 @@ if(USE_CUDA AND (NOT GENERATE_COMPILATION_DATABASE))
     link_directories(${NCCL_LIBRARY})
     target_link_libraries(gpuxgboost ${NCCL_LIB_NAME})
   endif()
+
   list(APPEND LINK_LIBRARIES gpuxgboost)
 
 elseif (USE_CUDA AND GENERATE_COMPILATION_DATABASE)
diff --git a/doc/gpu/index.rst b/doc/gpu/index.rst
index 071d66c23..09ccde60c 100644
--- a/doc/gpu/index.rst
+++ b/doc/gpu/index.rst
@@ -195,6 +195,17 @@ Training time on 1,000,000 rows x 50 columns with 500 boosting iterations a
 See `GPU Accelerated XGBoost `_ and `Updates to the XGBoost GPU algorithms `_ for additional performance benchmarks of the ``gpu_exact`` and ``gpu_hist`` tree methods.
 
+Developer notes
+===============
+The application may be profiled with annotations by passing ``USE_NVTX`` to CMake and providing the path to the stand-alone NVTX header via ``NVTX_HEADER_DIR``. Regions covered by the ``Monitor`` class in CUDA code will then appear automatically in the NVIDIA Nsight profiler.
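+
+For example, a profiling-enabled build might be configured as follows (a minimal sketch; the header location varies by system, ``/usr/local/cuda/include`` is only a common default):
+
+.. code-block:: bash
+
+  cmake .. -DUSE_CUDA=ON -DUSE_NVTX=ON -DNVTX_HEADER_DIR=/usr/local/cuda/include
+  make -j4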
+
 **********
 References
 **********
diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh
index 4e28daccb..923ee03cd 100644
--- a/src/common/device_helpers.cuh
+++ b/src/common/device_helpers.cuh
@@ -308,7 +308,7 @@ class DVec {
     }
     safe_cuda(cudaSetDevice(this->DeviceIdx()));
     if (other.DeviceIdx() == this->DeviceIdx()) {
-      dh::safe_cuda(cudaMemcpy(this->Data(), other.Data(),
+      dh::safe_cuda(cudaMemcpyAsync(this->Data(), other.Data(),
                                other.Size() * sizeof(T),
                                cudaMemcpyDeviceToDevice));
     } else {
@@ -338,7 +338,7 @@ class DVec {
       throw std::runtime_error(
           "Cannot copy assign vector to dvec, sizes are different");
     }
-    safe_cuda(cudaMemcpy(this->Data(), begin.get(), Size() * sizeof(T),
+    safe_cuda(cudaMemcpyAsync(this->Data(), begin.get(), Size() * sizeof(T),
                          cudaMemcpyDefault));
   }
 };
diff --git a/src/common/hist_util.cu b/src/common/hist_util.cu
index 0f2f89b95..0c41bcb67 100644
--- a/src/common/hist_util.cu
+++ b/src/common/hist_util.cu
@@ -290,14 +290,14 @@ struct GPUSketcher {
         offset_vec[row_begin_ + batch_row_begin];
     // copy the batch to the GPU
     dh::safe_cuda
-        (cudaMemcpy(entries_.data().get(),
+        (cudaMemcpyAsync(entries_.data().get(),
                     data_vec.data() + offset_vec[row_begin_ + batch_row_begin],
                     n_entries * sizeof(Entry), cudaMemcpyDefault));
     // copy the weights if necessary
     if (has_weights_) {
       const auto& weights_vec = info.weights_.HostVector();
       dh::safe_cuda
-          (cudaMemcpy(weights_.data().get(),
+          (cudaMemcpyAsync(weights_.data().get(),
                       weights_vec.data() + row_begin_ + batch_row_begin,
                       batch_nrows * sizeof(bst_float), cudaMemcpyDefault));
     }
@@ -315,15 +315,11 @@ struct GPUSketcher {
         has_weights_ ? weights_.data().get() : nullptr,
         entries_.data().get(),
         gpu_batch_nrows_, num_cols_,
         offset_vec[row_begin_ + batch_row_begin], batch_nrows);
-    dh::safe_cuda(cudaGetLastError());  // NOLINT
-    dh::safe_cuda(cudaDeviceSynchronize());  // NOLINT
 
     for (int icol = 0; icol < num_cols_; ++icol) {
       FindColumnCuts(batch_nrows, icol);
     }
-    dh::safe_cuda(cudaDeviceSynchronize());  // NOLINT
-
     // add cuts into sketches
     thrust::copy(cuts_d_.begin(), cuts_d_.end(), cuts_h_.begin());
     for (int icol = 0; icol < num_cols_; ++icol) {
diff --git a/src/common/host_device_vector.cu b/src/common/host_device_vector.cu
index 6e33adfcb..df8541f09 100644
--- a/src/common/host_device_vector.cu
+++ b/src/common/host_device_vector.cu
@@ -74,14 +74,14 @@ struct HostDeviceVectorImpl {
     // TODO(canonizer): avoid full copy of host data
     LazySyncDevice(GPUAccess::kWrite);
     SetDevice();
-    dh::safe_cuda(cudaMemcpy(data_.data().get(), begin + start_,
+    dh::safe_cuda(cudaMemcpyAsync(data_.data().get(), begin + start_,
                              data_.size() * sizeof(T), cudaMemcpyDefault));
   }
 
   void GatherTo(thrust::device_ptr<T> begin) {
     LazySyncDevice(GPUAccess::kRead);
     SetDevice();
-    dh::safe_cuda(cudaMemcpy(begin.get() + start_, data_.data().get(),
+    dh::safe_cuda(cudaMemcpyAsync(begin.get() + start_, data_.data().get(),
                              proper_size_ * sizeof(T), cudaMemcpyDefault));
   }
 
@@ -97,7 +97,7 @@ struct HostDeviceVectorImpl {
     LazySyncDevice(GPUAccess::kWrite);
     other->LazySyncDevice(GPUAccess::kRead);
     SetDevice();
-    dh::safe_cuda(cudaMemcpy(data_.data().get(), other->data_.data().get(),
+    dh::safe_cuda(cudaMemcpyAsync(data_.data().get(), other->data_.data().get(),
                              data_.size() * sizeof(T), cudaMemcpyDefault));
   }
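The cudaMemcpy → cudaMemcpyAsync swaps above (and the removed cudaDeviceSynchronize calls elsewhere in this patch) rely on stream ordering: operations queued on the same CUDA stream execute in issue order, so a kernel launched after an asynchronous copy never sees a partially copied buffer, and the synchronous device-to-host copies left at the points where results are consumed provide the host-side barrier. A standalone sketch of that pattern (illustrative only, not from the XGBoost sources):

```cuda
#include <cstdio>
#include <cuda_runtime.h>

__global__ void Square(int* data, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) data[i] *= data[i];
}

int main() {
  const int n = 256;
  int host[n];
  for (int i = 0; i < n; ++i) host[i] = i;

  int* device = nullptr;
  cudaMalloc(&device, n * sizeof(int));
  // The copy and the kernel are queued on the same (default) stream, so the
  // kernel starts only after the copy has finished -- no explicit sync needed.
  cudaMemcpyAsync(device, host, n * sizeof(int), cudaMemcpyDefault);
  Square<<<(n + 255) / 256, 256>>>(device, n);
  // The synchronous copy back is the single point where the host waits.
  cudaMemcpy(host, device, n * sizeof(int), cudaMemcpyDefault);
  printf("host[5] = %d\n", host[5]);  // prints 25
  cudaFree(device);
  return 0;
}
```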
diff --git a/src/common/timer.h b/src/common/timer.h
index a4b631783..c9aa9de18 100644
--- a/src/common/timer.h
+++ b/src/common/timer.h
@@ -8,7 +8,9 @@
 #include <string>
 #include <vector>
 
-#include "common.h"
+#if defined(XGBOOST_USE_NVTX) && defined(__CUDACC__)
+#include "nvToolsExt.h"
+#endif
 
 namespace xgboost {
 namespace common {
@@ -45,9 +47,11 @@ struct Timer {
  */
 struct Monitor {
+ private:
   struct Statistics {
     Timer timer;
     size_t count{0};
+    uint64_t nvtx_id;
   };
   std::string label = "";
   std::map<std::string, Statistics> statistics_map;
 
@@ -75,35 +79,37 @@ struct Monitor {
     }
     self_timer.Stop();
   }
-  void Init(std::string label) {
-    this->label = label;
-  }
-  void Start(const std::string &name) { statistics_map[name].timer.Start(); }
-  void Start(const std::string &name, GPUSet devices) {
+  void Init(std::string label) { this->label = label; }
+  void Start(const std::string &name) {
     if (ConsoleLogger::ShouldLog(ConsoleLogger::LV::kDebug)) {
-#ifdef __CUDACC__
-      for (auto device : devices) {
-        cudaSetDevice(device);
-        cudaDeviceSynchronize();
-      }
-#endif  // __CUDACC__
+      statistics_map[name].timer.Start();
     }
-    statistics_map[name].timer.Start();
   }
   void Stop(const std::string &name) {
-    statistics_map[name].timer.Stop();
-    statistics_map[name].count++;
-  }
-  void Stop(const std::string &name, GPUSet devices) {
     if (ConsoleLogger::ShouldLog(ConsoleLogger::LV::kDebug)) {
-#ifdef __CUDACC__
-      for (auto device : devices) {
-        cudaSetDevice(device);
-        cudaDeviceSynchronize();
-      }
-#endif  // __CUDACC__
+      auto &stats = statistics_map[name];
+      stats.timer.Stop();
+      stats.count++;
+    }
+  }
+  void StartCuda(const std::string &name) {
+    if (ConsoleLogger::ShouldLog(ConsoleLogger::LV::kDebug)) {
+      auto &stats = statistics_map[name];
+      stats.timer.Start();
+#if defined(XGBOOST_USE_NVTX) && defined(__CUDACC__)
+      stats.nvtx_id = nvtxRangeStartA(name.c_str());
+#endif
+    }
+  }
+  void StopCuda(const std::string &name) {
+    if (ConsoleLogger::ShouldLog(ConsoleLogger::LV::kDebug)) {
+      auto &stats = statistics_map[name];
+      stats.timer.Stop();
+      stats.count++;
+#if defined(XGBOOST_USE_NVTX) && defined(__CUDACC__)
+      nvtxRangeEnd(stats.nvtx_id);
+#endif
     }
-    this->Stop(name);
   }
 };
 }  // namespace common
diff --git a/src/common/transform.h b/src/common/transform.h
index c5382b4d7..6a8f83c83 100644
--- a/src/common/transform.h
+++ b/src/common/transform.h
@@ -145,8 +145,6 @@ class Transform {
           static_cast<uint32_t>(dh::DivRoundUp(*(range_.end()), kBlockThreads));
       detail::LaunchCUDAKernel<<<GRID_SIZE, kBlockThreads>>>(
          _func, shard_range, UnpackHDV(_vectors, device)...);
-      dh::safe_cuda(cudaGetLastError());
-      dh::safe_cuda(cudaDeviceSynchronize());
     }
   }
 #else
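The Monitor rewrite replaces the old GPUSet-taking overloads of Start/Stop (which synchronised every listed device) with explicit StartCuda/StopCuda calls, gates all bookkeeping behind debug-level logging, and — when built with USE_NVTX under nvcc — brackets each region with nvtxRangeStartA/nvtxRangeEnd so it shows up on the Nsight timeline. A minimal usage sketch (the call site is hypothetical; the Monitor API is the one defined above):

```cpp
#include "common/timer.h"  // include path assumed relative to src/

void BuildOneTree() {
  xgboost::common::Monitor monitor;
  monitor.Init("my_updater");      // label printed with the summary statistics

  monitor.StartCuda("BuildHist");  // starts the timer and, when compiled with
                                   // -DXGBOOST_USE_NVTX, opens an NVTX range
  // ... launch histogram kernels ...
  monitor.StopCuda("BuildHist");   // stops the timer and closes the range
}
```

Note that nothing is recorded unless the console verbosity is at debug level, so the annotations are free in release runs.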
diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu
index 750282a9e..3d6a1df64 100644
--- a/src/predictor/gpu_predictor.cu
+++ b/src/predictor/gpu_predictor.cu
@@ -252,17 +252,17 @@ class GPUPredictor : public xgboost::Predictor {
                  size_t tree_begin, size_t tree_end) {
     dh::safe_cuda(cudaSetDevice(device_));
     nodes.resize(h_nodes.size());
-    dh::safe_cuda(cudaMemcpy(dh::Raw(nodes), h_nodes.data(),
+    dh::safe_cuda(cudaMemcpyAsync(dh::Raw(nodes), h_nodes.data(),
                              sizeof(DevicePredictionNode) * h_nodes.size(),
                              cudaMemcpyHostToDevice));
 
     tree_segments.resize(h_tree_segments.size());
-    dh::safe_cuda(cudaMemcpy(dh::Raw(tree_segments), h_tree_segments.data(),
+    dh::safe_cuda(cudaMemcpyAsync(dh::Raw(tree_segments), h_tree_segments.data(),
                              sizeof(size_t) * h_tree_segments.size(),
                              cudaMemcpyHostToDevice));
 
     tree_group.resize(model.tree_info.size());
-    dh::safe_cuda(cudaMemcpy(dh::Raw(tree_group), model.tree_info.data(),
+    dh::safe_cuda(cudaMemcpyAsync(dh::Raw(tree_group), model.tree_info.data(),
                              sizeof(int) * model.tree_info.size(),
                              cudaMemcpyHostToDevice));
 
@@ -288,9 +288,6 @@ class GPUPredictor : public xgboost::Predictor {
         dh::ToSpan(tree_group), batch.offset.DeviceSpan(device_),
         batch.data.DeviceSpan(device_), tree_begin, tree_end, info.num_col_,
         num_rows, entry_start, use_shared, model.param.num_output_group);
-
-    dh::safe_cuda(cudaGetLastError());
-    dh::safe_cuda(cudaDeviceSynchronize());
   }
 
   int device_;
diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu
index f178b1f2b..c2511d378 100644
--- a/src/tree/updater_gpu_hist.cu
+++ b/src/tree/updater_gpu_hist.cu
@@ -289,7 +289,9 @@ struct DeviceHistogram {
 
   void Reset() {
     dh::safe_cuda(cudaSetDevice(device_id_));
-    data.resize(0);
+    dh::safe_cuda(cudaMemsetAsync(
+        data.data().get(), 0,
+        data.size() * sizeof(typename decltype(data)::value_type)));
     nidx_map.clear();
   }
 
@@ -299,20 +301,25 @@ struct DeviceHistogram {
 
   void AllocateHistogram(int nidx) {
     if (HistogramExists(nidx)) return;
-
+    size_t current_size =
+        nidx_map.size() * n_bins * 2;  // Number of items currently used in data
     dh::safe_cuda(cudaSetDevice(device_id_));
-    if (data.size() > kStopGrowingSize) {
+    if (data.size() >= kStopGrowingSize) {
       // Recycle histogram memory
       std::pair<int, size_t> old_entry = *nidx_map.begin();
       nidx_map.erase(old_entry.first);
-      dh::safe_cuda(cudaMemset(data.data().get() + old_entry.second, 0,
+      dh::safe_cuda(cudaMemsetAsync(data.data().get() + old_entry.second, 0,
                                n_bins * sizeof(GradientSumT)));
       nidx_map[nidx] = old_entry.second;
     } else {
       // Append new node histogram
-      nidx_map[nidx] = data.size();
-      // x 2: Hess and Grad.
-      data.resize(data.size() + (n_bins * 2));
+      nidx_map[nidx] = current_size;
+      if (data.size() < current_size + n_bins * 2) {
+        size_t new_size = current_size * 2;  // Double in size
+        new_size = std::max(static_cast<size_t>(n_bins * 2),
+                            new_size);  // Have at least one histogram
+        data.resize(new_size);
+      }
     }
   }
 
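Two behavioural changes ride along with the async conversion in DeviceHistogram: Reset() now zeroes the existing allocation instead of shrinking it to zero, so the buffer is reused across trees, and below kStopGrowingSize the buffer grows by doubling rather than by exactly one histogram per node, making the number of reallocations logarithmic in the node count. A host-side sketch of the same growth policy (hypothetical names, std::vector standing in for the device buffer):

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <vector>

// Each histogram needs n_bins * 2 entries (gradient and hessian per bin).
struct HistogramPool {
  std::vector<double> data;
  size_t used = 0;  // entries handed out so far

  size_t Allocate(size_t n_bins) {
    size_t need = used + n_bins * 2;
    if (data.size() < need) {
      // Double the buffer, but always leave room for at least one histogram.
      size_t new_size = std::max(n_bins * 2, used * 2);
      data.resize(new_size);
    }
    size_t offset = used;
    used += n_bins * 2;
    return offset;
  }
};

int main() {
  HistogramPool pool;
  for (int nidx = 0; nidx < 1000; ++nidx) pool.Allocate(256);
  // Reached via ~log2(1000) reallocations rather than 1000 of them.
  printf("capacity: %zu entries\n", pool.data.size());
  return 0;
}
```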
@@ -610,20 +617,20 @@ struct DeviceShard {
     feature_set_d.resize(feature_set.size());
     auto d_features = common::Span<int>(feature_set_d.data().get(),
                                         feature_set_d.size());
-    dh::safe_cuda(cudaMemcpy(d_features.data(), feature_set.data(),
+    dh::safe_cuda(cudaMemcpyAsync(d_features.data(), feature_set.data(),
                              d_features.size_bytes(), cudaMemcpyDefault));
     DeviceNodeStats node(node_sum_gradients[nidx], nidx, param);
 
     // One block for each feature
     int constexpr BLOCK_THREADS = 256;
     EvaluateSplitKernel<BLOCK_THREADS, GradientSumT>
-        <<<uint32_t(feature_set.size()), BLOCK_THREADS>>>
-        (hist.GetNodeHistogram(nidx), d_features, node,
-         cut_.feature_segments.GetSpan(), cut_.min_fvalue.GetSpan(),
-         cut_.gidx_fvalue_map.GetSpan(), GPUTrainingParam(param),
-         d_split_candidates, value_constraint, monotone_constraints.GetSpan());
+        <<<uint32_t(feature_set.size()), BLOCK_THREADS>>>(
+            hist.GetNodeHistogram(nidx), d_features, node,
+            cut_.feature_segments.GetSpan(), cut_.min_fvalue.GetSpan(),
+            cut_.gidx_fvalue_map.GetSpan(), GPUTrainingParam(param),
+            d_split_candidates, value_constraint,
+            monotone_constraints.GetSpan());
 
-    dh::safe_cuda(cudaDeviceSynchronize());
     std::vector<DeviceSplitCandidate> split_candidates(feature_set.size());
     dh::safe_cuda(cudaMemcpy(split_candidates.data(), d_split_candidates.data(),
                              split_candidates.size() * sizeof(DeviceSplitCandidate),
@@ -725,20 +732,21 @@ struct DeviceShard {
         common::Span<bst_uint>(ridx.Current() + segment.begin, segment.Size()),
         common::Span<bst_uint>(ridx.other() + segment.begin, segment.Size()),
         left_nidx, right_nidx, left_count);
-    // Copy back key
-    dh::safe_cuda(cudaMemcpy(
-        position.Current() + segment.begin, position.other() + segment.begin,
-        segment.Size() * sizeof(int), cudaMemcpyDeviceToDevice));
-    // Copy back value
-    dh::safe_cuda(cudaMemcpy(
-        ridx.Current() + segment.begin, ridx.other() + segment.begin,
-        segment.Size() * sizeof(bst_uint), cudaMemcpyDeviceToDevice));
+    // Copy back key/value
+    const auto d_position_current = position.Current() + segment.begin;
+    const auto d_position_other = position.other() + segment.begin;
+    const auto d_ridx_current = ridx.Current() + segment.begin;
+    const auto d_ridx_other = ridx.other() + segment.begin;
+    dh::LaunchN(device_id_, segment.Size(), [=] __device__(size_t idx) {
+      d_position_current[idx] = d_position_other[idx];
+      d_ridx_current[idx] = d_ridx_other[idx];
+    });
   }
 
   void UpdatePredictionCache(bst_float* out_preds_d) {
     dh::safe_cuda(cudaSetDevice(device_id_));
     if (!prediction_cache_initialised) {
-      dh::safe_cuda(cudaMemcpy(
+      dh::safe_cuda(cudaMemcpyAsync(
           prediction_cache.Data(), out_preds_d,
           prediction_cache.Size() * sizeof(bst_float), cudaMemcpyDefault));
     }
@@ -746,7 +754,7 @@ struct DeviceShard {
 
     CalcWeightTrainParam<GradientSumT> param_d(param);
 
-    dh::safe_cuda(cudaMemcpy(node_sum_gradients_d.Data(),
+    dh::safe_cuda(cudaMemcpyAsync(node_sum_gradients_d.Data(),
                              node_sum_gradients.data(),
                              sizeof(GradientPair) * node_sum_gradients.size(),
                              cudaMemcpyHostToDevice));
@@ -925,9 +933,6 @@ inline void DeviceShard<GradientSumT>::CreateHistIndices(const SparsePage& row_b
       batch_row_begin, batch_nrows,
       row_ptrs[batch_row_begin],
       row_stride, null_gidx_value);
-
-    dh::safe_cuda(cudaGetLastError());
-    dh::safe_cuda(cudaDeviceSynchronize());
   }
 
   // free the memory that is no longer needed
@@ -965,7 +970,7 @@ class GPUHistMakerSpecialised {
 
   void Update(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
               const std::vector<RegTree*>& trees) {
-    monitor_.Start("Update", dist_.Devices());
+    monitor_.StartCuda("Update");
     // rescale learning rate according to size of trees
     float lr = param_.learning_rate;
     param_.learning_rate = lr / trees.size();
@@ -980,7 +985,7 @@ class GPUHistMakerSpecialised {
       LOG(FATAL) << "Exception in gpu_hist: " << e.what() << std::endl;
     }
     param_.learning_rate = lr;
-    monitor_.Stop("Update", dist_.Devices());
+    monitor_.StopCuda("Update");
   }
 
   void InitDataOnce(DMatrix* dmat) {
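In the UpdatePosition hunk above, two back-to-back device-to-device cudaMemcpy calls become a single fused kernel through dh::LaunchN, removing one launch from the per-split critical path while staying stream-ordered. A self-contained sketch of that pattern with a local LaunchN stand-in (illustrative only; requires nvcc's --expt-extended-lambda):

```cuda
#include <cuda_runtime.h>

template <typename L>
__global__ void LaunchNKernel(size_t n, L lambda) {
  // Grid-stride loop so any grid size covers all n elements.
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += gridDim.x * blockDim.x) {
    lambda(i);
  }
}

template <typename L>
void LaunchN(size_t n, L lambda) {
  if (n == 0) return;
  const int kBlock = 256;
  LaunchNKernel<<<static_cast<int>((n + kBlock - 1) / kBlock), kBlock>>>(n, lambda);
}

// One kernel writes both arrays, replacing two cudaMemcpyDeviceToDevice calls.
void CopyBack(int* pos_cur, const int* pos_other,
              unsigned* ridx_cur, const unsigned* ridx_other, size_t n) {
  LaunchN(n, [=] __device__(size_t i) {
    pos_cur[i] = pos_other[i];    // key
    ridx_cur[i] = ridx_other[i];  // value
  });
}

int main() {
  const size_t n = 1024;
  int *pos_a, *pos_b;
  unsigned *ridx_a, *ridx_b;
  cudaMalloc(&pos_a, n * sizeof(int));
  cudaMalloc(&pos_b, n * sizeof(int));
  cudaMalloc(&ridx_a, n * sizeof(unsigned));
  cudaMalloc(&ridx_b, n * sizeof(unsigned));
  CopyBack(pos_a, pos_b, ridx_a, ridx_b, n);
  cudaDeviceSynchronize();
  return 0;
}
```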
- monitor_.Start("Quantiles", dist_.Devices()); + monitor_.StartCuda("Quantiles"); common::DeviceSketch(batch, *info_, param_, &hmat_, hist_maker_param_.gpu_batch_nrows); n_bins_ = hmat_.row_ptr.back(); - monitor_.Stop("Quantiles", dist_.Devices()); + monitor_.StopCuda("Quantiles"); - monitor_.Start("BinningCompression", dist_.Devices()); + monitor_.StartCuda("BinningCompression"); dh::ExecuteIndexShards(&shards_, [&](int idx, std::unique_ptr>& shard) { shard->InitCompressedData(hmat_, batch); }); - monitor_.Stop("BinningCompression", dist_.Devices()); + monitor_.StopCuda("BinningCompression"); ++batch_iter; CHECK(batch_iter.AtEnd()) << "External memory not supported"; @@ -1030,16 +1035,16 @@ class GPUHistMakerSpecialised{ void InitData(HostDeviceVector* gpair, DMatrix* dmat) { if (!initialised_) { - monitor_.Start("InitDataOnce", dist_.Devices()); + monitor_.StartCuda("InitDataOnce"); this->InitDataOnce(dmat); - monitor_.Stop("InitDataOnce", dist_.Devices()); + monitor_.StopCuda("InitDataOnce"); } column_sampler_.Init(info_->num_col_, param_.colsample_bynode, param_.colsample_bylevel, param_.colsample_bytree); // Copy gpair & reset memory - monitor_.Start("InitDataReset", dist_.Devices()); + monitor_.StartCuda("InitDataReset"); gpair->Reshard(dist_); dh::ExecuteIndexShards( @@ -1047,13 +1052,12 @@ class GPUHistMakerSpecialised{ [&](int idx, std::unique_ptr>& shard) { shard->Reset(gpair); }); - monitor_.Stop("InitDataReset", dist_.Devices()); + monitor_.StopCuda("InitDataReset"); } void AllReduceHist(int nidx) { - if (shards_.size() == 1 && !rabit::IsDistributed()) - return; - monitor_.Start("AllReduce"); + if (shards_.size() == 1 && !rabit::IsDistributed()) return; + monitor_.StartCuda("AllReduce"); reducer_.GroupStart(); for (auto& shard : shards_) { @@ -1067,7 +1071,7 @@ class GPUHistMakerSpecialised{ reducer_.GroupEnd(); reducer_.Synchronize(); - monitor_.Stop("AllReduce"); + monitor_.StopCuda("AllReduce"); } /** @@ -1250,12 +1254,12 @@ class GPUHistMakerSpecialised{ RegTree* p_tree) { auto& tree = *p_tree; - monitor_.Start("InitData", dist_.Devices()); + monitor_.StartCuda("InitData"); this->InitData(gpair, p_fmat); - monitor_.Stop("InitData", dist_.Devices()); - monitor_.Start("InitRoot", dist_.Devices()); + monitor_.StopCuda("InitData"); + monitor_.StartCuda("InitRoot"); this->InitRoot(p_tree); - monitor_.Stop("InitRoot", dist_.Devices()); + monitor_.StopCuda("InitRoot"); auto timestamp = qexpand_->size(); auto num_leaves = 1; @@ -1266,9 +1270,9 @@ class GPUHistMakerSpecialised{ if (!candidate.IsValid(param_, num_leaves)) continue; this->ApplySplit(candidate, p_tree); - monitor_.Start("UpdatePosition", dist_.Devices()); + monitor_.StartCuda("UpdatePosition"); this->UpdatePosition(candidate, p_tree); - monitor_.Stop("UpdatePosition", dist_.Devices()); + monitor_.StopCuda("UpdatePosition"); num_leaves++; int left_child_nidx = tree[candidate.nid].LeftChild(); @@ -1277,32 +1281,30 @@ class GPUHistMakerSpecialised{ // Only create child entries if needed if (ExpandEntry::ChildIsValid(param_, tree.GetDepth(left_child_nidx), num_leaves)) { - monitor_.Start("BuildHist", dist_.Devices()); + monitor_.StartCuda("BuildHist"); this->BuildHistLeftRight(candidate.nid, left_child_nidx, right_child_nidx); - monitor_.Stop("BuildHist", dist_.Devices()); + monitor_.StopCuda("BuildHist"); - monitor_.Start("EvaluateSplits", dist_.Devices()); - auto left_child_split = - this->EvaluateSplit(left_child_nidx, p_tree); - auto right_child_split = - this->EvaluateSplit(right_child_nidx, p_tree); + 
@@ -1250,12 +1254,12 @@
                 RegTree* p_tree) {
     auto& tree = *p_tree;
 
-    monitor_.Start("InitData", dist_.Devices());
+    monitor_.StartCuda("InitData");
     this->InitData(gpair, p_fmat);
-    monitor_.Stop("InitData", dist_.Devices());
-    monitor_.Start("InitRoot", dist_.Devices());
+    monitor_.StopCuda("InitData");
+    monitor_.StartCuda("InitRoot");
     this->InitRoot(p_tree);
-    monitor_.Stop("InitRoot", dist_.Devices());
+    monitor_.StopCuda("InitRoot");
 
     auto timestamp = qexpand_->size();
     auto num_leaves = 1;
@@ -1266,9 +1270,9 @@
       if (!candidate.IsValid(param_, num_leaves)) continue;
 
       this->ApplySplit(candidate, p_tree);
-      monitor_.Start("UpdatePosition", dist_.Devices());
+      monitor_.StartCuda("UpdatePosition");
       this->UpdatePosition(candidate, p_tree);
-      monitor_.Stop("UpdatePosition", dist_.Devices());
+      monitor_.StopCuda("UpdatePosition");
       num_leaves++;
 
       int left_child_nidx = tree[candidate.nid].LeftChild();
@@ -1277,32 +1281,30 @@
       // Only create child entries if needed
       if (ExpandEntry::ChildIsValid(param_, tree.GetDepth(left_child_nidx),
                                     num_leaves)) {
-        monitor_.Start("BuildHist", dist_.Devices());
+        monitor_.StartCuda("BuildHist");
         this->BuildHistLeftRight(candidate.nid, left_child_nidx,
                                  right_child_nidx);
-        monitor_.Stop("BuildHist", dist_.Devices());
+        monitor_.StopCuda("BuildHist");
 
-        monitor_.Start("EvaluateSplits", dist_.Devices());
-        auto left_child_split =
-            this->EvaluateSplit(left_child_nidx, p_tree);
-        auto right_child_split =
-            this->EvaluateSplit(right_child_nidx, p_tree);
+        monitor_.StartCuda("EvaluateSplits");
+        auto left_child_split = this->EvaluateSplit(left_child_nidx, p_tree);
+        auto right_child_split = this->EvaluateSplit(right_child_nidx, p_tree);
         qexpand_->push(ExpandEntry(left_child_nidx,
-                                   tree.GetDepth(left_child_nidx), left_child_split,
-                                   timestamp++));
+                                   tree.GetDepth(left_child_nidx),
+                                   left_child_split, timestamp++));
         qexpand_->push(ExpandEntry(right_child_nidx,
-                                   tree.GetDepth(right_child_nidx), right_child_split,
-                                   timestamp++));
-        monitor_.Stop("EvaluateSplits", dist_.Devices());
+                                   tree.GetDepth(right_child_nidx),
+                                   right_child_split, timestamp++));
+        monitor_.StopCuda("EvaluateSplits");
       }
     }
   }
 
   bool UpdatePredictionCache(
       const DMatrix* data, HostDeviceVector<bst_float>* p_out_preds) {
-    monitor_.Start("UpdatePredictionCache", dist_.Devices());
     if (shards_.empty() || p_last_fmat_ == nullptr || p_last_fmat_ != data)
       return false;
+    monitor_.StartCuda("UpdatePredictionCache");
     p_out_preds->Reshard(dist_.Devices());
     dh::ExecuteIndexShards(
         &shards_,
@@ -1310,7 +1312,7 @@
         [&](int idx, std::unique_ptr<DeviceShard<GradientSumT>>& shard) {
           shard->UpdatePredictionCache(
               p_out_preds->DevicePointer(shard->device_id_));
         });
-    monitor_.Stop("UpdatePredictionCache", dist_.Devices());
+    monitor_.StopCuda("UpdatePredictionCache");
     return true;
   }
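Note the small correctness fix folded into UpdatePredictionCache: the monitor call now comes after the early return, since starting a range and then bailing out leaves an unmatched Start — the timer counts an interval that is never closed and, under NVTX, the range never ends. A contrived sketch of the hazard and the fix (not XGBoost code):

```cpp
#include <cstdio>

struct RangeCounter {
  int open = 0;
  void Start() { ++open; }
  void Stop() { --open; }
};

static RangeCounter ranges;

// Pre-fix shape: Start() precedes the early return, so every bail-out
// leaks an open range and inflates the call count.
bool UpdateCacheBuggy(bool cache_valid) {
  ranges.Start();
  if (!cache_valid) return false;  // Stop() never runs on this path
  ranges.Stop();
  return true;
}

// Post-fix shape, as in the diff: bail out first, then open the range.
bool UpdateCacheFixed(bool cache_valid) {
  if (!cache_valid) return false;
  ranges.Start();
  ranges.Stop();
  return true;
}

int main() {
  UpdateCacheBuggy(false);
  printf("unmatched ranges after buggy call: %d\n", ranges.open);  // 1
  UpdateCacheFixed(false);
  printf("after fixed call: %d\n", ranges.open);  // still 1 -- no new leak
  return 0;
}
```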