Change reduce operation from thrust to cub. Fix for cuda 9.1 error (#3218)

* Change reduce operation from thrust to cub. Fix for cuda 9.1 runtime error

* Unit test sum reduce
This commit is contained in:
Rory Mitchell 2018-04-04 14:21:48 +12:00 committed by GitHub
parent 017acf54d9
commit a1ec7b1716
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 61 additions and 30 deletions

View File

@ -797,6 +797,29 @@ void sumReduction(dh::CubMemory &tmp_mem, dh::dvec<T> &in, dh::dvec<T> &out,
in.data(), out.data(), nVals));
}
/**
* @brief Helper function to perform device-wide sum-reduction, returns to the
* host
* @param tmp_mem cub temporary memory info
* @param in the input array to be reduced
* @param nVals number of elements in the input array
*/
template <typename T>
T sumReduction(dh::CubMemory &tmp_mem, T *in, int nVals) {
size_t tmpSize;
dh::safe_cuda(cub::DeviceReduce::Sum(nullptr, tmpSize, in, in, nVals));
// Allocate small extra memory for the return value
tmp_mem.LazyAllocate(tmpSize + sizeof(T));
auto ptr = reinterpret_cast<T *>(tmp_mem.d_temp_storage) + 1;
dh::safe_cuda(cub::DeviceReduce::Sum(
reinterpret_cast<void *>(ptr), tmpSize, in,
reinterpret_cast<T *>(tmp_mem.d_temp_storage), nVals));
T sum;
dh::safe_cuda(cudaMemcpy(&sum, tmp_mem.d_temp_storage, sizeof(T),
cudaMemcpyDeviceToHost));
return sum;
}
/**
* @brief Fill a given constant value across all elements in the buffer
* @param out the buffer to be filled

View File

@ -217,9 +217,11 @@ struct CalcWeightTrainParam {
float max_delta_step;
float learning_rate;
__host__ __device__ CalcWeightTrainParam(const TrainParam& p)
: min_child_weight(p.min_child_weight), reg_alpha(p.reg_alpha),
reg_lambda(p.reg_lambda), max_delta_step(p.max_delta_step),
learning_rate(p.learning_rate) {}
: min_child_weight(p.min_child_weight),
reg_alpha(p.reg_alpha),
reg_lambda(p.reg_lambda),
max_delta_step(p.max_delta_step),
learning_rate(p.learning_rate) {}
};
// Manage memory for a single GPU
@ -502,10 +504,9 @@ struct DeviceShard {
void UpdatePredictionCache(bst_float* out_preds_d) {
dh::safe_cuda(cudaSetDevice(device_idx));
if (!prediction_cache_initialised) {
dh::safe_cuda(cudaMemcpy
(prediction_cache.data(), &out_preds_d[row_begin_idx],
prediction_cache.size() * sizeof(bst_float),
cudaMemcpyDefault));
dh::safe_cuda(cudaMemcpy(
prediction_cache.data(), &out_preds_d[row_begin_idx],
prediction_cache.size() * sizeof(bst_float), cudaMemcpyDefault));
}
prediction_cache_initialised = true;
@ -518,18 +519,17 @@ struct DeviceShard {
auto d_node_sum_gradients = node_sum_gradients_d.data();
auto d_prediction_cache = prediction_cache.data();
dh::launch_n(device_idx, prediction_cache.size(),
[=] __device__(int local_idx) {
int pos = d_position[local_idx];
bst_float weight = CalcWeight(param_d, d_node_sum_gradients[pos]);
d_prediction_cache[d_ridx[local_idx]] +=
weight * param_d.learning_rate;
});
dh::launch_n(
device_idx, prediction_cache.size(), [=] __device__(int local_idx) {
int pos = d_position[local_idx];
bst_float weight = CalcWeight(param_d, d_node_sum_gradients[pos]);
d_prediction_cache[d_ridx[local_idx]] +=
weight * param_d.learning_rate;
});
dh::safe_cuda(cudaMemcpy
(&out_preds_d[row_begin_idx], prediction_cache.data(),
prediction_cache.size() * sizeof(bst_float),
cudaMemcpyDefault));
dh::safe_cuda(cudaMemcpy(
&out_preds_d[row_begin_idx], prediction_cache.data(),
prediction_cache.size() * sizeof(bst_float), cudaMemcpyDefault));
}
};
@ -748,11 +748,10 @@ class GPUHistMaker : public TreeUpdater {
#pragma omp parallel
{
auto cpu_thread_id = omp_get_thread_num();
dh::safe_cuda(cudaSetDevice(shards[cpu_thread_id]->device_idx));
tmp_sums[cpu_thread_id] =
thrust::reduce(thrust::cuda::par(shards[cpu_thread_id]->temp_memory),
shards[cpu_thread_id]->gpair.tbegin(),
shards[cpu_thread_id]->gpair.tend());
auto& shard = shards[cpu_thread_id];
dh::safe_cuda(cudaSetDevice(shard->device_idx));
tmp_sums[cpu_thread_id] = dh::sumReduction(
shard->temp_memory, shard->gpair.data(), shard->gpair.size());
}
auto sum_gradient =
std::accumulate(tmp_sums.begin(), tmp_sums.end(), bst_gpair_precise());
@ -909,15 +908,15 @@ class GPUHistMaker : public TreeUpdater {
omp_set_num_threads(nthread);
}
bool UpdatePredictionCache
(const DMatrix* data, HostDeviceVector<bst_float>* p_out_preds) override {
bool UpdatePredictionCache(
const DMatrix* data, HostDeviceVector<bst_float>* p_out_preds) override {
monitor.Start("UpdatePredictionCache", dList);
if (shards.empty() || p_last_fmat_ == nullptr || p_last_fmat_ != data)
return false;
bst_float *out_preds_d = p_out_preds->ptr_d(param.gpu_id);
bst_float* out_preds_d = p_out_preds->ptr_d(param.gpu_id);
#pragma omp parallel for schedule(static, 1)
#pragma omp parallel for schedule(static, 1)
for (int shard = 0; shard < shards.size(); ++shard) {
shards[shard]->UpdatePredictionCache(out_preds_d);
}

View File

@ -5,8 +5,8 @@
#include <thrust/device_vector.h>
#include <xgboost/base.h>
#include "../../../src/common/device_helpers.cuh"
#include "gtest/gtest.h"
#include "../../../src/common/timer.h"
#include "gtest/gtest.h"
void CreateTestData(xgboost::bst_uint num_rows, int max_row_size,
thrust::host_vector<int> *row_ptr,
@ -38,7 +38,8 @@ void SpeedTest() {
xgboost::common::Timer t;
dh::TransformLbs(
0, &temp_memory, h_rows.size(), dh::raw(row_ptr), row_ptr.size() - 1, false,
0, &temp_memory, h_rows.size(), dh::raw(row_ptr), row_ptr.size() - 1,
false,
[=] __device__(size_t idx, size_t ridx) { d_output_row[idx] = ridx; });
dh::safe_cuda(cudaDeviceSynchronize());
@ -76,4 +77,12 @@ void TestLbs() {
}
}
}
TEST(cub_lbs, Test) { TestLbs(); }
TEST(sumReduce, Test) {
thrust::device_vector<float> data(100, 1.0f);
dh::CubMemory temp;
auto sum = dh::sumReduction(temp, dh::raw(data), data.size());
ASSERT_NEAR(sum, 100.0f, 1e-5);
}

View File

@ -39,7 +39,7 @@ TEST(Metric, AUCPR) {
0.5f, 0.001f);
EXPECT_NEAR(
GetMetricEval(metric,
{0.4f, 0.2f, 0.9f, 0.1f, 0.2f, 0.4f, 0.1f, 0.1f, 0.2f, 0.1},
{0.4f, 0.2f, 0.9f, 0.1f, 0.2f, 0.4f, 0.1f, 0.1f, 0.2f, 0.1f},
{0, 0, 0, 0, 0, 1, 0, 0, 1, 1}),
0.2908445f, 0.001f);
EXPECT_NEAR(GetMetricEval(