diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh
index 69215a535..c849730dc 100644
--- a/src/common/device_helpers.cuh
+++ b/src/common/device_helpers.cuh
@@ -7,6 +7,7 @@
 #include
 #include
 #include
+#include <algorithm>

 #include "common.h"
 #include "span.h"
@@ -784,6 +785,7 @@ class AllReducer {
   bool initialised_;
   size_t allreduce_bytes_;  // Keep statistics of the number of bytes communicated
   size_t allreduce_calls_;  // Keep statistics of the number of reduce calls
+  std::vector<size_t> host_data;  // Used for all reduce on host
 #ifdef XGBOOST_USE_NCCL
   std::vector<ncclComm_t> comms;
   std::vector<cudaStream_t> streams;
@@ -1024,6 +1026,42 @@ class AllReducer {
     return id;
   }
 #endif
+  /** \brief Perform a max all-reduce operation on the host. This function
+   * first reduces over OMP threads, then over nodes using rabit (which is not
+   * thread safe), from the master thread. Uses a naive reduce algorithm for
+   * local threads; don't expect this to scale. */
+  void HostMaxAllReduce(std::vector<size_t> *p_data) {
+    auto &data = *p_data;
+    // Wait in case some other thread is accessing host_data
+#pragma omp barrier
+    // Reset shared buffer
+#pragma omp single
+    {
+      host_data.resize(data.size());
+      std::fill(host_data.begin(), host_data.end(), size_t(0));
+    }
+    // Threads update shared array
+    for (auto i = 0ull; i < data.size(); i++) {
+#pragma omp critical
+      { host_data[i] = std::max(host_data[i], data[i]); }
+    }
+    // Wait until all threads are finished
+#pragma omp barrier
+
+    // One thread performs the all-reduce across distributed nodes
+#pragma omp master
+    {
+      rabit::Allreduce<rabit::op::Max>(host_data.data(),
+                                       host_data.size());
+    }
+
+#pragma omp barrier
+
+    // Threads can now read back the all-reduced values
+    for (auto i = 0ull; i < data.size(); i++) {
+      data[i] = host_data[i];
+    }
+  }
 };

 /**
@@ -1044,7 +1082,7 @@ void ExecuteIndexShards(std::vector<T> *shards, FunctionT f) {
   bool dynamic = omp_get_dynamic();
   omp_set_dynamic(false);
   const long shards_size = static_cast<long>(shards->size());
-#pragma omp parallel for schedule(static, 1) if (shards_size > 1)
+#pragma omp parallel for schedule(static, 1) if (shards_size > 1) num_threads(shards_size)
   for (long shard = 0; shard < shards_size; ++shard) {
     f(shard, shards->at(shard));
   }
diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu
index 0a4f0d7f8..b41cb3632 100644
--- a/src/tree/updater_gpu_hist.cu
+++ b/src/tree/updater_gpu_hist.cu
@@ -1054,15 +1054,17 @@ struct DeviceShard {
     auto build_hist_nidx = nidx_left;
     auto subtraction_trick_nidx = nidx_right;

-    // If we are using a single GPU, build the histogram for the node with the
-    // fewest training instances
-    // If we are distributed, don't bother
-    if (reducer->IsSingleGPU()) {
-      bool fewer_right =
-          ridx_segments[nidx_right].Size() < ridx_segments[nidx_left].Size();
-      if (fewer_right) {
-        std::swap(build_hist_nidx, subtraction_trick_nidx);
-      }
+    auto left_node_rows = ridx_segments[nidx_left].Size();
+    auto right_node_rows = ridx_segments[nidx_right].Size();
+    // Decide whether to build the left or the right histogram. Find the
+    // largest number of training instances for each node on any given shard,
+    // assume that node will be the bottleneck, and avoid building it if
+    // possible.
+    std::vector<size_t> max_reduce = {left_node_rows, right_node_rows};
+    reducer->HostMaxAllReduce(&max_reduce);
+    bool fewer_right = max_reduce[1] < max_reduce[0];
+    if (fewer_right) {
+      std::swap(build_hist_nidx, subtraction_trick_nidx);
     }

     this->BuildHist(build_hist_nidx);
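Note on the updater change above: routing the row counts through `HostMaxAllReduce` means every worker sees the same `max_reduce` values and therefore makes the same build/subtract choice; deciding from local row counts alone could leave workers building different nodes. Below is a minimal standalone sketch of that decision logic (not part of the patch), using hypothetical per-worker row counts in place of the real `ridx_segments` sizes.

```cpp
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  // Hypothetical per-worker (left, right) row counts; in the patch these come
  // from ridx_segments[nidx_left].Size() and ridx_segments[nidx_right].Size().
  std::vector<std::vector<std::size_t>> worker_rows = {{100, 900}, {850, 50}};

  // Element-wise max across workers -- the value HostMaxAllReduce would
  // deliver to every worker.
  std::vector<std::size_t> max_reduce = {0, 0};
  for (const auto &rows : worker_rows) {
    max_reduce[0] = std::max(max_reduce[0], rows[0]);
    max_reduce[1] = std::max(max_reduce[1], rows[1]);
  }

  // Every worker evaluates the same condition: build the histogram for the
  // node whose worst-case (largest) shard is smaller, and obtain the other
  // node's histogram via the subtraction trick.
  bool fewer_right = max_reduce[1] < max_reduce[0];
  std::cout << "build " << (fewer_right ? "right" : "left")
            << " node, derive the other by subtraction\n";
  return 0;
}
```

With these hypothetical numbers a purely local decision would diverge: worker 0 would build the left node and worker 1 the right one, whereas the reduced maxima {850, 900} make both build the left node.
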
diff --git a/tests/cpp/common/test_device_helpers.cu b/tests/cpp/common/test_device_helpers.cu
index 0fbf6e5eb..c9cb1e61a 100644
--- a/tests/cpp/common/test_device_helpers.cu
+++ b/tests/cpp/common/test_device_helpers.cu
@@ -95,3 +95,20 @@ void TestAllocator() {
 TEST(bulkAllocator, Test) {
   TestAllocator();
 }
+
+// Test thread-safe max reduction
+TEST(AllReducer, HostMaxAllReduce) {
+  dh::AllReducer reducer;
+  size_t num_threads = 50;
+  std::vector<std::vector<size_t>> thread_data(num_threads);
+#pragma omp parallel num_threads(num_threads)
+  {
+    int tid = omp_get_thread_num();
+    thread_data[tid] = {size_t(tid)};
+    reducer.HostMaxAllReduce(&thread_data[tid]);
+  }
+
+  for (auto data : thread_data) {
+    ASSERT_EQ(data.front(), num_threads - 1);
+  }
+}
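The synchronization pattern the new test exercises can also be tried outside of XGBoost. The sketch below is not the library implementation: it mirrors the barrier/single/critical/master structure of `HostMaxAllReduce`, but replaces `rabit::Allreduce` with a no-op stub (`StubNodeMaxAllReduce`, a name invented here), so only the intra-process, OpenMP half of the reduction is demonstrated.

```cpp
// Standalone sketch of the thread-then-node max reduction (compile with -fopenmp).
#include <omp.h>

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Stand-in for rabit::Allreduce<rabit::op::Max>; with a single process the
// node-level reduction is the identity, so this does nothing.
void StubNodeMaxAllReduce(std::size_t *, std::size_t) {}

std::vector<std::size_t> host_data;  // Shared scratch buffer, as in AllReducer

void HostMaxAllReduceSketch(std::vector<std::size_t> *p_data) {
  auto &data = *p_data;
#pragma omp barrier  // Wait in case another thread is still using host_data
#pragma omp single
  { host_data.assign(data.size(), 0); }  // Reset the shared buffer
  // Each thread folds its values into the shared buffer
  for (std::size_t i = 0; i < data.size(); i++) {
#pragma omp critical
    { host_data[i] = std::max(host_data[i], data[i]); }
  }
#pragma omp barrier  // All threads have contributed
#pragma omp master
  { StubNodeMaxAllReduce(host_data.data(), host_data.size()); }
#pragma omp barrier  // Node-level result is ready
  // Every thread reads the reduced values back
  for (std::size_t i = 0; i < data.size(); i++) {
    data[i] = host_data[i];
  }
}

int main() {
  const int num_threads = 8;
  std::vector<std::vector<std::size_t>> thread_data(num_threads);
#pragma omp parallel num_threads(num_threads)
  {
    int tid = omp_get_thread_num();
    thread_data[tid] = {static_cast<std::size_t>(tid)};
    HostMaxAllReduceSketch(&thread_data[tid]);
  }
  // Every thread should now observe the global maximum, num_threads - 1
  for (const auto &data : thread_data) {
    assert(data.front() == static_cast<std::size_t>(num_threads - 1));
  }
  return 0;
}
```

In the real class, the `#pragma omp master` block is where the per-node maxima are combined across rabit workers; everything else is bookkeeping to keep the shared `host_data` buffer consistent between threads.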