From a1ec7b1716f78d333ead277cc17b3c04097a2b7b Mon Sep 17 00:00:00 2001
From: Rory Mitchell
Date: Wed, 4 Apr 2018 14:21:48 +1200
Subject: [PATCH] Change reduce operation from thrust to cub. Fix for cuda 9.1 error (#3218)

* Change reduce operation from thrust to cub. Fix for cuda 9.1 runtime error

* Unit test sum reduce
---
 src/common/device_helpers.cuh           | 23 +++++++++++
 src/tree/updater_gpu_hist.cu            | 53 ++++++++++++-------------
 tests/cpp/common/test_device_helpers.cu | 13 +++++-
 tests/cpp/metric/test_rank_metric.cc    |  2 +-
 4 files changed, 61 insertions(+), 30 deletions(-)

diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh
index c10c6a4eb..a171d15d3 100644
--- a/src/common/device_helpers.cuh
+++ b/src/common/device_helpers.cuh
@@ -797,6 +797,29 @@ void sumReduction(dh::CubMemory &tmp_mem, dh::dvec<T> &in, dh::dvec<T> &out,
                    in.data(), out.data(), nVals));
 }
 
+/**
+ * @brief Helper function to perform device-wide sum-reduction, returns to the
+ * host
+ * @param tmp_mem cub temporary memory info
+ * @param in the input array to be reduced
+ * @param nVals number of elements in the input array
+ */
+template <typename T>
+T sumReduction(dh::CubMemory &tmp_mem, T *in, int nVals) {
+  size_t tmpSize;
+  dh::safe_cuda(cub::DeviceReduce::Sum(nullptr, tmpSize, in, in, nVals));
+  // Allocate small extra memory for the return value
+  tmp_mem.LazyAllocate(tmpSize + sizeof(T));
+  auto ptr = reinterpret_cast<T *>(tmp_mem.d_temp_storage) + 1;
+  dh::safe_cuda(cub::DeviceReduce::Sum(
+      reinterpret_cast<void *>(ptr), tmpSize, in,
+      reinterpret_cast<T *>(tmp_mem.d_temp_storage), nVals));
+  T sum;
+  dh::safe_cuda(cudaMemcpy(&sum, tmp_mem.d_temp_storage, sizeof(T),
+                           cudaMemcpyDeviceToHost));
+  return sum;
+}
+
 /**
  * @brief Fill a given constant value across all elements in the buffer
  * @param out the buffer to be filled
diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu
index b951b8a9e..c3ff507a7 100644
--- a/src/tree/updater_gpu_hist.cu
+++ b/src/tree/updater_gpu_hist.cu
@@ -217,9 +217,11 @@ struct CalcWeightTrainParam {
   float max_delta_step;
   float learning_rate;
   __host__ __device__ CalcWeightTrainParam(const TrainParam& p)
-      : min_child_weight(p.min_child_weight), reg_alpha(p.reg_alpha),
-        reg_lambda(p.reg_lambda), max_delta_step(p.max_delta_step),
-        learning_rate(p.learning_rate) {}
+      : min_child_weight(p.min_child_weight),
+        reg_alpha(p.reg_alpha),
+        reg_lambda(p.reg_lambda),
+        max_delta_step(p.max_delta_step),
+        learning_rate(p.learning_rate) {}
 };
 
 // Manage memory for a single GPU
@@ -502,10 +504,9 @@ struct DeviceShard {
   void UpdatePredictionCache(bst_float* out_preds_d) {
     dh::safe_cuda(cudaSetDevice(device_idx));
     if (!prediction_cache_initialised) {
-      dh::safe_cuda(cudaMemcpy
-                    (prediction_cache.data(), &out_preds_d[row_begin_idx],
-                     prediction_cache.size() * sizeof(bst_float),
-                     cudaMemcpyDefault));
+      dh::safe_cuda(cudaMemcpy(
+          prediction_cache.data(), &out_preds_d[row_begin_idx],
+          prediction_cache.size() * sizeof(bst_float), cudaMemcpyDefault));
     }
     prediction_cache_initialised = true;
@@ -518,18 +519,17 @@ struct DeviceShard {
     auto d_node_sum_gradients = node_sum_gradients_d.data();
     auto d_prediction_cache = prediction_cache.data();
 
-    dh::launch_n(device_idx, prediction_cache.size(),
-                 [=] __device__(int local_idx) {
-                   int pos = d_position[local_idx];
-                   bst_float weight = CalcWeight(param_d, d_node_sum_gradients[pos]);
-                   d_prediction_cache[d_ridx[local_idx]] +=
-                       weight * param_d.learning_rate;
-                 });
+    dh::launch_n(
+        device_idx, prediction_cache.size(), [=] __device__(int local_idx) {
+          int pos = d_position[local_idx];
+          bst_float weight = CalcWeight(param_d, d_node_sum_gradients[pos]);
+          d_prediction_cache[d_ridx[local_idx]] +=
+              weight * param_d.learning_rate;
+        });
 
-    dh::safe_cuda(cudaMemcpy
-                  (&out_preds_d[row_begin_idx], prediction_cache.data(),
-                   prediction_cache.size() * sizeof(bst_float),
-                   cudaMemcpyDefault));
+    dh::safe_cuda(cudaMemcpy(
+        &out_preds_d[row_begin_idx], prediction_cache.data(),
+        prediction_cache.size() * sizeof(bst_float), cudaMemcpyDefault));
   }
 };
@@ -748,11 +748,10 @@ class GPUHistMaker : public TreeUpdater {
 #pragma omp parallel
     {
       auto cpu_thread_id = omp_get_thread_num();
-      dh::safe_cuda(cudaSetDevice(shards[cpu_thread_id]->device_idx));
-      tmp_sums[cpu_thread_id] =
-          thrust::reduce(thrust::cuda::par(shards[cpu_thread_id]->temp_memory),
-                         shards[cpu_thread_id]->gpair.tbegin(),
-                         shards[cpu_thread_id]->gpair.tend());
+      auto& shard = shards[cpu_thread_id];
+      dh::safe_cuda(cudaSetDevice(shard->device_idx));
+      tmp_sums[cpu_thread_id] = dh::sumReduction(
+          shard->temp_memory, shard->gpair.data(), shard->gpair.size());
     }
     auto sum_gradient =
         std::accumulate(tmp_sums.begin(), tmp_sums.end(), bst_gpair_precise());
@@ -909,15 +908,15 @@ class GPUHistMaker : public TreeUpdater {
     omp_set_num_threads(nthread);
   }
 
-  bool UpdatePredictionCache
-  (const DMatrix* data, HostDeviceVector<bst_float>* p_out_preds) override {
+  bool UpdatePredictionCache(
+      const DMatrix* data, HostDeviceVector<bst_float>* p_out_preds) override {
     monitor.Start("UpdatePredictionCache", dList);
     if (shards.empty() || p_last_fmat_ == nullptr || p_last_fmat_ != data)
       return false;
-    bst_float *out_preds_d = p_out_preds->ptr_d(param.gpu_id);
+    bst_float* out_preds_d = p_out_preds->ptr_d(param.gpu_id);
 
-  #pragma omp parallel for schedule(static, 1)
+#pragma omp parallel for schedule(static, 1)
     for (int shard = 0; shard < shards.size(); ++shard) {
       shards[shard]->UpdatePredictionCache(out_preds_d);
     }
diff --git a/tests/cpp/common/test_device_helpers.cu b/tests/cpp/common/test_device_helpers.cu
index c51afb3f8..ad00328f1 100644
--- a/tests/cpp/common/test_device_helpers.cu
+++ b/tests/cpp/common/test_device_helpers.cu
@@ -5,8 +5,8 @@
 #include <thrust/device_vector.h>
 #include <xgboost/base.h>
 #include "../../../src/common/device_helpers.cuh"
-#include "gtest/gtest.h"
 #include "../../../src/common/timer.h"
+#include "gtest/gtest.h"
 
 void CreateTestData(xgboost::bst_uint num_rows, int max_row_size,
                     thrust::host_vector<int> *row_ptr,
@@ -38,7 +38,8 @@ void SpeedTest() {
 
   xgboost::common::Timer t;
   dh::TransformLbs(
-      0, &temp_memory, h_rows.size(), dh::raw(row_ptr), row_ptr.size() - 1, false,
+      0, &temp_memory, h_rows.size(), dh::raw(row_ptr), row_ptr.size() - 1,
+      false,
       [=] __device__(size_t idx, size_t ridx) { d_output_row[idx] = ridx; });
 
   dh::safe_cuda(cudaDeviceSynchronize());
@@ -76,4 +77,12 @@ void TestLbs() {
     }
   }
 }
+
 TEST(cub_lbs, Test) { TestLbs(); }
+
+TEST(sumReduce, Test) {
+  thrust::device_vector<float> data(100, 1.0f);
+  dh::CubMemory temp;
+  auto sum = dh::sumReduction(temp, dh::raw(data), data.size());
+  ASSERT_NEAR(sum, 100.0f, 1e-5);
+}
diff --git a/tests/cpp/metric/test_rank_metric.cc b/tests/cpp/metric/test_rank_metric.cc
index f0f7a0090..1d271521b 100644
--- a/tests/cpp/metric/test_rank_metric.cc
+++ b/tests/cpp/metric/test_rank_metric.cc
@@ -39,7 +39,7 @@ TEST(Metric, AUCPR) {
               0.5f, 0.001f);
   EXPECT_NEAR(
       GetMetricEval(metric,
-                    {0.4f, 0.2f, 0.9f, 0.1f, 0.2f, 0.4f, 0.1f, 0.1f, 0.2f, 0.1},
+                    {0.4f, 0.2f, 0.9f, 0.1f, 0.2f, 0.4f, 0.1f, 0.1f, 0.2f, 0.1f},
                     {0, 0, 0, 0, 0, 1, 0, 0, 1, 1}),
       0.2908445f, 0.001f);
   EXPECT_NEAR(GetMetricEval(
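
For reference, here is the two-pass cub::DeviceReduce::Sum pattern that the new dh::sumReduction() helper wraps, as a self-contained sketch. The first call passes a null temp-storage pointer, so CUB only reports how many scratch bytes it needs; the second call runs the actual reduction; a single cudaMemcpy then returns the sum to the host. The buffer layout (result in the first sizeof(T) bytes, scratch after it) mirrors the tmpSize + sizeof(T) trick in the patch. The file name, main() harness, and variable names below are invented for illustration; only the CUB calls reflect the patch.

// sum_reduce_sketch.cu (hypothetical file, not part of the patch)
// Compile with: nvcc sum_reduce_sketch.cu -o sum_reduce_sketch
#include <cub/cub.cuh>
#include <thrust/device_vector.h>
#include <cstdio>

int main() {
  thrust::device_vector<float> d_in(100, 1.0f);
  int n = static_cast<int>(d_in.size());
  float* in = thrust::raw_pointer_cast(d_in.data());

  // Pass 1: with a null temp-storage pointer, cub::DeviceReduce::Sum only
  // writes the required scratch size into temp_bytes and returns.
  size_t temp_bytes = 0;
  cub::DeviceReduce::Sum(nullptr, temp_bytes, in, in, n);

  // Allocate the scratch space plus one extra float to hold the device-side
  // result, mirroring tmp_mem.LazyAllocate(tmpSize + sizeof(T)) above.
  void* d_storage = nullptr;
  cudaMalloc(&d_storage, temp_bytes + sizeof(float));
  float* d_result = reinterpret_cast<float*>(d_storage);  // first float
  void* d_scratch = d_result + 1;                         // remainder

  // Pass 2: the actual reduction, writing the sum into d_result.
  cub::DeviceReduce::Sum(d_scratch, temp_bytes, in, d_result, n);

  // Bring the result back to the host, as sumReduction() does.
  float sum = 0.0f;
  cudaMemcpy(&sum, d_result, sizeof(float), cudaMemcpyDeviceToHost);
  printf("sum = %f\n", sum);  // should print 100.000000

  cudaFree(d_storage);
  return 0;
}

This sketch should produce the same value the new sumReduce unit test asserts (ASSERT_NEAR(sum, 100.0f, 1e-5)). Error checking is omitted for brevity, whereas the patch wraps every CUDA and CUB call in dh::safe_cuda; the switch to this pattern replaces the thrust::reduce call associated with the cuda 9.1 runtime error the commit message cites.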