diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh
index d0556a30c..e0d3e41ed 100644
--- a/src/common/device_helpers.cuh
+++ b/src/common/device_helpers.cuh
@@ -19,6 +19,7 @@
 #include
 #include
 #include
+#include "timer.h"
 
 #ifdef XGBOOST_USE_NCCL
 #include "nccl.h"
@@ -840,14 +841,17 @@ void Gather(int device_idx, T *out, const T *in, const int *instId, int nVals) {
  */
 class AllReducer {
-  bool initialised;
+  bool initialised_;
+  bool debug_verbose_;
+  size_t allreduce_bytes_;  // Keep statistics of the number of bytes communicated
+  size_t allreduce_calls_;  // Keep statistics of the number of reduce calls
 #ifdef XGBOOST_USE_NCCL
   std::vector<ncclComm_t> comms;
   std::vector<cudaStream_t> streams;
   std::vector<int> device_ordinals;
 #endif
  public:
-  AllReducer() : initialised(false) {}
+  AllReducer()
+      : initialised_(false), debug_verbose_(false), allreduce_bytes_(0),
+        allreduce_calls_(0) {}
 
@@ -858,8 +862,10 @@ class AllReducer {
   /**
    * \fn void Init(const std::vector<int> &device_ordinals)
    *
    * \brief Initialise with the desired device ordinals for this communication
    * group.
    *
    * \param device_ordinals The device ordinals.
    */
-  void Init(const std::vector<int> &device_ordinals) {
+  void Init(const std::vector<int> &device_ordinals, bool debug_verbose) {
 #ifdef XGBOOST_USE_NCCL
+    this->debug_verbose_ = debug_verbose;
     this->device_ordinals = device_ordinals;
     comms.resize(device_ordinals.size());
     dh::safe_nccl(ncclCommInitAll(comms.data(),
@@ -870,7 +876,7 @@
       safe_cuda(cudaSetDevice(device_ordinals[i]));
       safe_cuda(cudaStreamCreate(&streams[i]));
     }
-    initialised = true;
+    initialised_ = true;
 #else
     CHECK_EQ(device_ordinals.size(), 1)
         << "XGBoost must be compiled with NCCL to use more than one GPU.";
@@ -878,7 +884,7 @@
   }
   ~AllReducer() {
 #ifdef XGBOOST_USE_NCCL
-    if (initialised) {
+    if (initialised_) {
       for (auto &stream : streams) {
         dh::safe_cuda(cudaStreamDestroy(stream));
       }
       for (auto &comm : comms) {
         ncclCommDestroy(comm);
       }
     }
+    if (debug_verbose_) {
+      LOG(CONSOLE) << "======== NCCL Statistics ========";
+      LOG(CONSOLE) << "AllReduce calls: " << allreduce_calls_;
+      LOG(CONSOLE) << "AllReduce total MB communicated: "
+                   << allreduce_bytes_ / 1000000;
+    }
 #endif
   }
@@ -920,11 +931,16 @@
   void AllReduceSum(int communication_group_idx, const double *sendbuff,
                     double *recvbuff, int count) {
 #ifdef XGBOOST_USE_NCCL
-    CHECK(initialised);
+    CHECK(initialised_);
     dh::safe_cuda(cudaSetDevice(device_ordinals.at(communication_group_idx)));
     dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclDouble, ncclSum,
                                 comms.at(communication_group_idx),
                                 streams.at(communication_group_idx)));
+    if (communication_group_idx == 0) {
+      allreduce_bytes_ += count * sizeof(double);
+      allreduce_calls_ += 1;
+    }
 #endif
   }
@@ -942,7 +958,7 @@
   void AllReduceSum(int communication_group_idx, const int64_t *sendbuff,
                     int64_t *recvbuff, int count) {
 #ifdef XGBOOST_USE_NCCL
-    CHECK(initialised);
+    CHECK(initialised_);
     dh::safe_cuda(cudaSetDevice(device_ordinals[communication_group_idx]));
     dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclInt64, ncclSum,
@@ -989,27 +1005,6 @@ class SaveCudaContext {
   }
 };
 
-/**
- * \brief Executes some operation on each element of the input vector, using a
- * single controlling thread for each element.
- *
- * \tparam T Generic type parameter.
- * \tparam FunctionT Type of the function t.
- * \param shards The shards.
- * \param f The func_t to process.
- */
-
-template <typename T, typename FunctionT>
-void ExecuteShards(std::vector<T> *shards, FunctionT f) {
-  SaveCudaContext {
-    [&](){
-#pragma omp parallel for schedule(static, 1) if (shards->size() > 1)
-      for (int shard = 0; shard < shards->size(); ++shard) {
-        f(shards->at(shard));
-      }
-    }};
-}
-
 /**
  * \brief Executes some operation on each element of the input vector, using a
  * single controlling thread for each element. In addition, passes the shard index
@@ -1023,13 +1018,12 @@ void ExecuteShards(std::vector<T> *shards, FunctionT f) {
 template <typename T, typename FunctionT>
 void ExecuteIndexShards(std::vector<T> *shards, FunctionT f) {
-  SaveCudaContext {
-    [&](){
+  SaveCudaContext{[&]() {
 #pragma omp parallel for schedule(static, 1) if (shards->size() > 1)
-  for (int shard = 0; shard < shards->size(); ++shard) {
-    f(shard, shards->at(shard));
-  }
-  }};
+    for (int shard = 0; shard < shards->size(); ++shard) {
+      f(shard, shards->at(shard));
+    }
+  }};
 }
 
 /**
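
The statistics counters above are only incremented for communication group index 0, so a reduction spanning several GPUs is counted once rather than once per device, and the totals are printed when the reducer is destroyed. A minimal, NCCL-free sketch of the same accounting pattern (the ReduceStats name and its Record signature are illustrative, not part of the patch):

```cpp
#include <cstddef>
#include <iostream>

// Illustrative stand-in for the AllReducer bookkeeping above: count calls and
// bytes for one designated group index only, and report totals on destruction.
class ReduceStats {
 public:
  explicit ReduceStats(bool verbose) : verbose_(verbose) {}
  void Record(int group_idx, int count, std::size_t elem_size) {
    if (group_idx == 0) {  // avoid double-counting multi-GPU groups
      bytes_ += static_cast<std::size_t>(count) * elem_size;
      calls_ += 1;
    }
  }
  ~ReduceStats() {
    if (verbose_) {
      std::cout << "AllReduce calls: " << calls_ << "\n"
                << "AllReduce total MB communicated: " << bytes_ / 1000000
                << std::endl;
    }
  }

 private:
  bool verbose_{false};
  std::size_t bytes_{0};
  std::size_t calls_{0};
};
```
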
diff --git a/src/common/hist_util.cu b/src/common/hist_util.cu
index 7c9e76059..7daf7fe0d 100644
--- a/src/common/hist_util.cu
+++ b/src/common/hist_util.cu
@@ -358,7 +358,7 @@ struct GPUSketcher {
     });
 
     // compute sketches for each shard
-    dh::ExecuteShards(&shards_, [&](std::unique_ptr<DeviceShard>& shard) {
+    dh::ExecuteIndexShards(&shards_, [&](int idx, std::unique_ptr<DeviceShard>& shard) {
      shard->Init(batch, info);
      shard->Sketch(batch, info);
    });
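
Every call site follows the same mechanical rewrite seen here: ExecuteShards(f) becomes ExecuteIndexShards(f) where the lambda gains a leading shard index, which may simply go unused. A sketch of the before and after shape, using a plain int vector in place of real shards (Example is a hypothetical function):

```cpp
#include <vector>
#include "device_helpers.cuh"  // for dh::ExecuteIndexShards

void Example(std::vector<int>* shards) {
  // Before: dh::ExecuteShards(shards, [&](int& shard) { shard *= 2; });
  dh::ExecuteIndexShards(shards, [&](int idx, int& shard) {
    (void)idx;  // index is available but unused at this call site
    shard *= 2;
  });
}
```
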
diff --git a/src/common/host_device_vector.cu b/src/common/host_device_vector.cu
index 12811daa7..9cb39d552 100644
--- a/src/common/host_device_vector.cu
+++ b/src/common/host_device_vector.cu
@@ -275,7 +275,7 @@ struct HostDeviceVectorImpl {
                                (end - begin) * sizeof(T),
                                cudaMemcpyDeviceToHost));
     } else {
-      dh::ExecuteShards(&shards_, [&](DeviceShard& shard) {
+      dh::ExecuteIndexShards(&shards_, [&](int idx, DeviceShard& shard) {
         shard.ScatterFrom(begin.get());
       });
     }
@@ -288,7 +288,7 @@ struct HostDeviceVectorImpl {
                                data_h_.size() * sizeof(T),
                                cudaMemcpyHostToDevice));
     } else {
-      dh::ExecuteShards(&shards_, [&](DeviceShard& shard) { shard.GatherTo(begin); });
+      dh::ExecuteIndexShards(&shards_, [&](int idx, DeviceShard& shard) { shard.GatherTo(begin); });
     }
   }
 
@@ -296,7 +296,7 @@ struct HostDeviceVectorImpl {
     if (perm_h_.CanWrite()) {
       std::fill(data_h_.begin(), data_h_.end(), v);
     } else {
-      dh::ExecuteShards(&shards_, [&](DeviceShard& shard) { shard.Fill(v); });
+      dh::ExecuteIndexShards(&shards_, [&](int idx, DeviceShard& shard) { shard.Fill(v); });
     }
   }
 
@@ -323,7 +323,7 @@ struct HostDeviceVectorImpl {
     if (perm_h_.CanWrite()) {
       std::copy(other.begin(), other.end(), data_h_.begin());
     } else {
-      dh::ExecuteShards(&shards_, [&](DeviceShard& shard) {
+      dh::ExecuteIndexShards(&shards_, [&](int idx, DeviceShard& shard) {
         shard.ScatterFrom(other.data());
       });
     }
@@ -334,7 +334,7 @@ struct HostDeviceVectorImpl {
     if (perm_h_.CanWrite()) {
       std::copy(other.begin(), other.end(), data_h_.begin());
     } else {
-      dh::ExecuteShards(&shards_, [&](DeviceShard& shard) {
+      dh::ExecuteIndexShards(&shards_, [&](int idx, DeviceShard& shard) {
         shard.ScatterFrom(other.begin());
       });
     }
@@ -387,14 +387,14 @@ struct HostDeviceVectorImpl {
     if (perm_h_.CanAccess(access)) { return; }
     if (perm_h_.CanRead()) {
       // data is present, just need to deny access to the device
-      dh::ExecuteShards(&shards_, [&](DeviceShard& shard) {
+      dh::ExecuteIndexShards(&shards_, [&](int idx, DeviceShard& shard) {
         shard.perm_d_.DenyComplementary(access);
       });
       perm_h_.Grant(access);
       return;
     }
     if (data_h_.size() != size_d_) { data_h_.resize(size_d_); }
-    dh::ExecuteShards(&shards_, [&](DeviceShard& shard) {
+    dh::ExecuteIndexShards(&shards_, [&](int idx, DeviceShard& shard) {
       shard.LazySyncHost(access);
     });
     perm_h_.Grant(access);
diff --git a/src/common/timer.h b/src/common/timer.h
index 4a7f3cf87..e4faff0cb 100644
--- a/src/common/timer.h
+++ b/src/common/timer.h
@@ -45,9 +45,13 @@ struct Timer {
  */
 struct Monitor {
+  struct Statistics {
+    Timer timer;
+    size_t count{0};
+  };
   bool debug_verbose = false;
   std::string label = "";
-  std::map<std::string, Timer> timer_map;
+  std::map<std::string, Statistics> statistics_map;
   Timer self_timer;
 
   Monitor() { self_timer.Start(); }
@@ -56,35 +60,46 @@ struct Monitor {
     if (!debug_verbose) return;
 
     LOG(CONSOLE) << "======== Monitor: " << label << " ========";
-    for (auto &kv : timer_map) {
-      kv.second.PrintElapsed(kv.first);
+    for (auto &kv : statistics_map) {
+      LOG(CONSOLE) << kv.first << ": " << kv.second.timer.ElapsedSeconds()
+                   << "s, " << kv.second.count << " calls @ "
+                   << std::chrono::duration_cast<std::chrono::microseconds>(
+                          kv.second.timer.elapsed / kv.second.count)
+                          .count()
+                   << "us";
     }
     self_timer.Stop();
-    self_timer.PrintElapsed(label + " Lifetime");
   }
   void Init(std::string label, bool debug_verbose) {
     this->debug_verbose = debug_verbose;
     this->label = label;
   }
-  void Start(const std::string &name) { timer_map[name].Start(); }
+  void Start(const std::string &name) { statistics_map[name].timer.Start(); }
   void Start(const std::string &name, GPUSet devices) {
     if (debug_verbose) {
 #ifdef __CUDACC__
-#include "device_helpers.cuh"
-      dh::SynchronizeNDevices(devices);
+      for (auto device : devices) {
+        cudaSetDevice(device);
+        cudaDeviceSynchronize();
+      }
 #endif
     }
-    timer_map[name].Start();
+    statistics_map[name].timer.Start();
+  }
+  void Stop(const std::string &name) {
+    statistics_map[name].timer.Stop();
+    statistics_map[name].count++;
   }
-  void Stop(const std::string &name) { timer_map[name].Stop(); }
   void Stop(const std::string &name, GPUSet devices) {
     if (debug_verbose) {
 #ifdef __CUDACC__
-#include "device_helpers.cuh"
-      dh::SynchronizeNDevices(devices);
+      for (auto device : devices) {
+        cudaSetDevice(device);
+        cudaDeviceSynchronize();
+      }
 #endif
     }
-    timer_map[name].Stop();
+    this->Stop(name);
   }
 };
 }  // namespace common
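
The reworked Monitor keeps a per-label call count alongside the timer and prints the mean time per call in microseconds, computed as total elapsed time divided by the number of Stop() calls (a label that is started but never stopped has a count of zero, so the printed mean assumes at least one Stop per label). A standalone sketch of that arithmetic:

```cpp
#include <chrono>
#include <cstddef>
#include <iostream>

int main() {
  using Clock = std::chrono::high_resolution_clock;
  Clock::duration elapsed{};  // accumulated across Start()/Stop() pairs
  std::size_t count = 0;

  for (int i = 0; i < 3; ++i) {
    auto start = Clock::now();
    // ... timed region ...
    elapsed += Clock::now() - start;
    ++count;  // incremented on every Stop(), as in Monitor::Stop above
  }

  auto mean_us = std::chrono::duration_cast<std::chrono::microseconds>(
                     elapsed / count)
                     .count();
  std::cout << count << " calls @ " << mean_us << "us" << std::endl;
}
```
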
diff --git a/src/linear/updater_gpu_coordinate.cu b/src/linear/updater_gpu_coordinate.cu
index 19241f04f..e8f43310c 100644
--- a/src/linear/updater_gpu_coordinate.cu
+++ b/src/linear/updater_gpu_coordinate.cu
@@ -258,7 +258,7 @@ class GPUCoordinateUpdater : public LinearUpdater {
     monitor.Start("UpdateGpair");
     // Update gpair
-    dh::ExecuteShards(&shards, [&](std::unique_ptr<DeviceShard> &shard) {
+    dh::ExecuteIndexShards(&shards, [&](int idx, std::unique_ptr<DeviceShard>& shard) {
       shard->UpdateGpair(in_gpair->ConstHostVector(), model->param);
     });
     monitor.Stop("UpdateGpair");
@@ -300,7 +300,7 @@
     model->bias()[group_idx] += dbias;
 
     // Update residual
-    dh::ExecuteShards(&shards, [&](std::unique_ptr<DeviceShard> &shard) {
+    dh::ExecuteIndexShards(&shards, [&](int idx, std::unique_ptr<DeviceShard>& shard) {
       shard->UpdateBiasResidual(dbias, group_idx, model->param.num_output_group);
     });
@@ -324,7 +324,7 @@
                             param.reg_lambda_denorm));
     w += dw;
 
-    dh::ExecuteShards(&shards, [&](std::unique_ptr<DeviceShard> &shard) {
+    dh::ExecuteIndexShards(&shards, [&](int idx, std::unique_ptr<DeviceShard>& shard) {
       shard->UpdateResidual(dw, group_idx, model->param.num_output_group, fidx);
     });
   }
diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu
index d3930f7d3..5ac8cbdf7 100644
--- a/src/predictor/gpu_predictor.cu
+++ b/src/predictor/gpu_predictor.cu
@@ -337,10 +337,10 @@ class GPUPredictor : public xgboost::Predictor {
       std::vector<size_t> device_offsets;
       DeviceOffsets(batch.offset, &device_offsets);
       batch.data.Reshard(GPUDistribution::Explicit(devices_, device_offsets));
-      dh::ExecuteShards(&shards, [&](DeviceShard& shard){
-        shard.PredictInternal(batch, dmat->Info(), out_preds, model, h_tree_segments,
-                              h_nodes, tree_begin, tree_end);
-      });
+      dh::ExecuteIndexShards(&shards, [&](int idx, DeviceShard& shard) {
+        shard.PredictInternal(batch, dmat->Info(), out_preds, model,
+                              h_tree_segments, h_nodes, tree_begin, tree_end);
+      });
       i_batch++;
     }
   }
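
The predictor hunk reshards each batch with explicit per-device offsets obtained from DeviceOffsets, where device i owns the contiguous row slice [offsets[i], offsets[i+1]). The sketch below illustrates only that contract with a hypothetical even split; it is not the DeviceOffsets implementation, which derives the boundaries from batch.offset:

```cpp
#include <cstddef>
#include <vector>

// Hypothetical even split of n_rows across n_devices, producing the
// [offsets[i], offsets[i+1]) slices that an explicit distribution consumes.
std::vector<std::size_t> EvenOffsets(std::size_t n_rows, int n_devices) {
  std::vector<std::size_t> offsets(n_devices + 1, 0);
  for (int i = 0; i < n_devices; ++i) {
    std::size_t extra =
        static_cast<std::size_t>(i) < n_rows % n_devices ? 1 : 0;
    offsets[i + 1] = offsets[i] + n_rows / n_devices + extra;
  }
  return offsets;
}
```
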
diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu
index 99d795063..2c11269f4 100644
--- a/src/tree/updater_gpu_hist.cu
+++ b/src/tree/updater_gpu_hist.cu
@@ -587,23 +587,6 @@ struct DeviceShard {
     return best_split;
   }
 
-  /** \brief Builds both left and right hist with subtraction trick if possible.
-   */
-  void BuildHistWithSubtractionTrick(int nidx_parent, int nidx_left,
-                                     int nidx_right) {
-    auto smallest_nidx =
-        ridx_segments[nidx_left].Size() < ridx_segments[nidx_right].Size()
-            ? nidx_left
-            : nidx_right;
-    auto largest_nidx = smallest_nidx == nidx_left ? nidx_right : nidx_left;
-    this->BuildHist(smallest_nidx);
-    if (this->CanDoSubtractionTrick(nidx_parent, smallest_nidx, largest_nidx)) {
-      this->SubtractionTrick(nidx_parent, smallest_nidx, largest_nidx);
-    } else {
-      this->BuildHist(largest_nidx);
-    }
-  }
-
   void BuildHist(int nidx) {
     hist.AllocateHistogram(nidx);
     hist_builder->Build(this, nidx);
@@ -954,7 +937,7 @@ class GPUHistMaker : public TreeUpdater {
       device_list_[index] = device_id;
     }
 
-    reducer_.Init(device_list_);
+    reducer_.Init(device_list_, param_.debug_verbose);
 
     auto batch_iter = dmat->GetRowBatches().begin();
     const SparsePage& batch = *batch_iter;
@@ -976,7 +959,7 @@
     monitor_.Stop("Quantiles", dist_.Devices());
 
     monitor_.Start("BinningCompression", dist_.Devices());
-    dh::ExecuteShards(&shards_, [&](std::unique_ptr<DeviceShard>& shard) {
+    dh::ExecuteIndexShards(&shards_, [&](int idx, std::unique_ptr<DeviceShard>& shard) {
       shard->InitCompressedData(hmat_, batch);
     });
     monitor_.Stop("BinningCompression", dist_.Devices());
@@ -1000,7 +983,7 @@
     monitor_.Start("InitDataReset", dist_.Devices());
 
     gpair->Reshard(dist_);
-    dh::ExecuteShards(&shards_, [&](std::unique_ptr<DeviceShard>& shard) {
+    dh::ExecuteIndexShards(&shards_, [&](int idx, std::unique_ptr<DeviceShard>& shard) {
       shard->Reset(gpair);
     });
     monitor_.Stop("InitDataReset", dist_.Devices());
@@ -1009,34 +992,66 @@
   void AllReduceHist(int nidx) {
     if (shards_.size() == 1) return;
 
-    reducer_.GroupStart();
-    for (auto& shard : shards_) {
+    monitor_.Start("AllReduce");
+    dh::ExecuteIndexShards(&shards_, [&](int idx, std::unique_ptr<DeviceShard>& shard) {
       auto d_node_hist = shard->hist.GetNodeHistogram(nidx).data();
       reducer_.AllReduceSum(
           dist_.Devices().Index(shard->device_id_),
           reinterpret_cast<GradientPairSumT::ValueT*>(d_node_hist),
           reinterpret_cast<GradientPairSumT::ValueT*>(d_node_hist),
          n_bins_ * (sizeof(GradientPairSumT) / sizeof(GradientPairSumT::ValueT)));
-    }
-    reducer_.GroupEnd();
-
-    reducer_.Synchronize();
+    });
+    monitor_.Stop("AllReduce");
   }
 
   /**
    * \brief Build GPU local histograms for the left and right child of some parent node
    */
   void BuildHistLeftRight(int nidx_parent, int nidx_left, int nidx_right) {
-    // If one GPU
-    if (shards_.size() == 1) {
-      shards_.back()->BuildHistWithSubtractionTrick(nidx_parent, nidx_left, nidx_right);
-    } else {
-      dh::ExecuteShards(&shards_, [&](std::unique_ptr<DeviceShard>& shard) {
-        shard->BuildHist(nidx_left);
-        shard->BuildHist(nidx_right);
+    size_t left_node_max_elements = 0;
+    size_t right_node_max_elements = 0;
+    for (auto& shard : shards_) {
+      left_node_max_elements = (std::max)(
+          left_node_max_elements, shard->ridx_segments[nidx_left].Size());
+      right_node_max_elements = (std::max)(
+          right_node_max_elements, shard->ridx_segments[nidx_right].Size());
+    }
+
+    auto build_hist_nidx = nidx_left;
+    auto subtraction_trick_nidx = nidx_right;
+
+    if (right_node_max_elements < left_node_max_elements) {
+      build_hist_nidx = nidx_right;
+      subtraction_trick_nidx = nidx_left;
+    }
+
+    // Build histogram for node with the smallest number of training examples
+    dh::ExecuteIndexShards(&shards_, [&](int idx, std::unique_ptr<DeviceShard>& shard) {
+      shard->BuildHist(build_hist_nidx);
+    });
+
+    this->AllReduceHist(build_hist_nidx);
+
+    // Check whether we can use the subtraction trick to calculate the other
+    bool do_subtraction_trick = true;
+    for (auto& shard : shards_) {
+      do_subtraction_trick &= shard->CanDoSubtractionTrick(
+          nidx_parent, build_hist_nidx, subtraction_trick_nidx);
+    }
+
+    if (do_subtraction_trick) {
+      // Calculate other histogram using subtraction trick
+      dh::ExecuteIndexShards(&shards_, [&](int idx, std::unique_ptr<DeviceShard>& shard) {
+        shard->SubtractionTrick(nidx_parent, build_hist_nidx,
+                                subtraction_trick_nidx);
       });
-      this->AllReduceHist(nidx_left);
-      this->AllReduceHist(nidx_right);
+    } else {
+      // Calculate other histogram manually
+      dh::ExecuteIndexShards(&shards_, [&](int idx, std::unique_ptr<DeviceShard>& shard) {
+        shard->BuildHist(subtraction_trick_nidx);
+      });
+
+      this->AllReduceHist(subtraction_trick_nidx);
     }
   }
@@ -1061,7 +1076,7 @@
         std::accumulate(tmp_sums.begin(), tmp_sums.end(), GradientPair());
 
     // Generate root histogram
-    dh::ExecuteShards(&shards_, [&](std::unique_ptr<DeviceShard>& shard) {
+    dh::ExecuteIndexShards(&shards_, [&](int idx, std::unique_ptr<DeviceShard>& shard) {
       shard->BuildHist(root_nidx);
     });
@@ -1107,7 +1122,7 @@
     }
 
     auto is_dense = info_->num_nonzero_ == info_->num_row_ * info_->num_col_;
-    dh::ExecuteShards(&shards_, [&](std::unique_ptr<DeviceShard>& shard) {
+    dh::ExecuteIndexShards(&shards_, [&](int idx, std::unique_ptr<DeviceShard>& shard) {
       shard->UpdatePosition(nidx, left_nidx, right_nidx, fidx, split_gidx,
                             default_dir_left, is_dense, fidx_begin,
                             fidx_end);
@@ -1153,7 +1168,6 @@
       shard->node_sum_gradients[parent.LeftChild()] = candidate.split.left_sum;
       shard->node_sum_gradients[parent.RightChild()] = candidate.split.right_sum;
     }
-    this->UpdatePosition(candidate, p_tree);
   }
 
   void UpdateTree(HostDeviceVector<GradientPair>* gpair, DMatrix* p_fmat,
@@ -1175,9 +1189,10 @@
       qexpand_->pop();
       if (!candidate.IsValid(param_, num_leaves)) continue;
 
-      monitor_.Start("ApplySplit", dist_.Devices());
       this->ApplySplit(candidate, p_tree);
-      monitor_.Stop("ApplySplit", dist_.Devices());
+      monitor_.Start("UpdatePosition", dist_.Devices());
+      this->UpdatePosition(candidate, p_tree);
+      monitor_.Stop("UpdatePosition", dist_.Devices());
       num_leaves++;
 
       int left_child_nidx = tree[candidate.nid].LeftChild();
@@ -1213,7 +1228,7 @@
     if (shards_.empty() || p_last_fmat_ == nullptr || p_last_fmat_ != data)
       return false;
     p_out_preds->Reshard(dist_.Devices());
-    dh::ExecuteShards(&shards_, [&](std::unique_ptr<DeviceShard>& shard) {
+    dh::ExecuteIndexShards(&shards_, [&](int idx, std::unique_ptr<DeviceShard>& shard) {
       shard->UpdatePredictionCache(
           p_out_preds->DevicePointer(shard->device_id_));
     });
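
A note on why the restructured BuildHistLeftRight is valid: every training row of the parent lands in exactly one child, so each parent histogram bin equals the sum of the corresponding child bins, and the sibling of the built node follows by per-bin subtraction. Since the parent and built-child histograms have already been reduced across GPUs, the difference needs no further AllReduce, which is why only build_hist_nidx and the manual fallback path call AllReduceHist. A minimal scalar sketch (the real kernel subtracts GradientPairSumT bins on the device):

```cpp
#include <cstddef>
#include <vector>

// hist(sibling) = hist(parent) - hist(built child), bin by bin.
std::vector<double> SubtractionTrick(const std::vector<double>& parent,
                                     const std::vector<double>& built) {
  std::vector<double> sibling(parent.size());
  for (std::size_t i = 0; i < parent.size(); ++i) {
    sibling[i] = parent[i] - built[i];
  }
  return sibling;
}
```
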
diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu
index 947e8b11c..cd4096fca 100644
--- a/tests/cpp/tree/test_gpu_hist.cu
+++ b/tests/cpp/tree/test_gpu_hist.cu
@@ -375,6 +375,7 @@ TEST(GpuHist, ApplySplit) {
   hist_maker.info_ = &info;
 
   hist_maker.ApplySplit(candidate_entry, &tree);
+  hist_maker.UpdatePosition(candidate_entry, &tree);
 
   ASSERT_FALSE(tree[nid].IsLeaf());