Change reduce operation from thrust to cub. Fix for CUDA 9.1 error (#3218)

* Change reduce operation from thrust to cub. Fix for CUDA 9.1 runtime error

* Unit test sum reduce
parent 017acf54d9
commit a1ec7b1716
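The fix replaces a Thrust reduction with CUB's DeviceReduce, which follows a two-pass protocol: a first call with a null temporary-storage pointer only reports how many scratch bytes are needed, and a second call performs the actual reduction. Below is a minimal standalone sketch of that pattern, independent of the XGBoost helpers (plain cudaMalloc instead of dh::CubMemory; names are illustrative):

// sum_reduce_sketch.cu -- two-pass cub::DeviceReduce::Sum (illustrative only)
#include <cub/cub.cuh>
#include <cstdio>
#include <vector>

int main() {
  const int n = 100;
  std::vector<float> h_in(n, 1.0f);  // expected sum: 100
  float *d_in = nullptr, *d_out = nullptr;
  cudaMalloc(&d_in, n * sizeof(float));
  cudaMalloc(&d_out, sizeof(float));
  cudaMemcpy(d_in, h_in.data(), n * sizeof(float), cudaMemcpyHostToDevice);

  // Pass 1: null temp storage, CUB only writes the required scratch size.
  void *d_temp = nullptr;
  size_t temp_bytes = 0;
  cub::DeviceReduce::Sum(d_temp, temp_bytes, d_in, d_out, n);

  // Pass 2: allocate the scratch buffer, then run the actual reduction.
  cudaMalloc(&d_temp, temp_bytes);
  cub::DeviceReduce::Sum(d_temp, temp_bytes, d_in, d_out, n);

  float sum = 0.0f;
  cudaMemcpy(&sum, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  printf("sum = %f\n", sum);  // prints 100.000000
  cudaFree(d_temp);
  cudaFree(d_out);
  cudaFree(d_in);
  return 0;
}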
@@ -797,6 +797,29 @@ void sumReduction(dh::CubMemory &tmp_mem, dh::dvec<T> &in, dh::dvec<T> &out,
                                        in.data(), out.data(), nVals));
 }
 
+/**
+ * @brief Helper function to perform device-wide sum-reduction, returns to the
+ * host
+ * @param tmp_mem cub temporary memory info
+ * @param in the input array to be reduced
+ * @param nVals number of elements in the input array
+ */
+template <typename T>
+T sumReduction(dh::CubMemory &tmp_mem, T *in, int nVals) {
+  size_t tmpSize;
+  dh::safe_cuda(cub::DeviceReduce::Sum(nullptr, tmpSize, in, in, nVals));
+  // Allocate small extra memory for the return value
+  tmp_mem.LazyAllocate(tmpSize + sizeof(T));
+  auto ptr = reinterpret_cast<T *>(tmp_mem.d_temp_storage) + 1;
+  dh::safe_cuda(cub::DeviceReduce::Sum(
+      reinterpret_cast<void *>(ptr), tmpSize, in,
+      reinterpret_cast<T *>(tmp_mem.d_temp_storage), nVals));
+  T sum;
+  dh::safe_cuda(cudaMemcpy(&sum, tmp_mem.d_temp_storage, sizeof(T),
+                           cudaMemcpyDeviceToHost));
+  return sum;
+}
+
 /**
  * @brief Fill a given constant value across all elements in the buffer
  * @param out the buffer to be filled
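A note on the layout used by the new helper above: it lazily allocates one buffer of tmpSize + sizeof(T) bytes, keeps the reduced value in the first sizeof(T) bytes, and hands the remainder (starting one T past the base pointer) to CUB as scratch, so the final device-to-host copy reads the result from the front of the same allocation. Usage mirrors the unit test added later in this commit; a sketch under that assumption, using the existing dh::raw and dh::CubMemory helpers:

#include <thrust/device_vector.h>
#include "../../../src/common/device_helpers.cuh"

float SumOfOnes() {
  thrust::device_vector<float> data(100, 1.0f);  // 100 ones on the device
  dh::CubMemory temp;                            // scratch buffer, reusable across calls
  // Device-wide sum returned to the host: 100.0f here.
  return dh::sumReduction(temp, dh::raw(data), data.size());
}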
@@ -217,9 +217,11 @@ struct CalcWeightTrainParam {
   float max_delta_step;
   float learning_rate;
   __host__ __device__ CalcWeightTrainParam(const TrainParam& p)
-      : min_child_weight(p.min_child_weight), reg_alpha(p.reg_alpha),
-        reg_lambda(p.reg_lambda), max_delta_step(p.max_delta_step),
-        learning_rate(p.learning_rate) {}
+      : min_child_weight(p.min_child_weight),
+        reg_alpha(p.reg_alpha),
+        reg_lambda(p.reg_lambda),
+        max_delta_step(p.max_delta_step),
+        learning_rate(p.learning_rate) {}
 };
 
 // Manage memory for a single GPU
@@ -502,10 +504,9 @@ struct DeviceShard {
   void UpdatePredictionCache(bst_float* out_preds_d) {
     dh::safe_cuda(cudaSetDevice(device_idx));
     if (!prediction_cache_initialised) {
-      dh::safe_cuda(cudaMemcpy
-          (prediction_cache.data(), &out_preds_d[row_begin_idx],
-           prediction_cache.size() * sizeof(bst_float),
-           cudaMemcpyDefault));
+      dh::safe_cuda(cudaMemcpy(
+          prediction_cache.data(), &out_preds_d[row_begin_idx],
+          prediction_cache.size() * sizeof(bst_float), cudaMemcpyDefault));
     }
     prediction_cache_initialised = true;
 
@@ -518,18 +519,17 @@ struct DeviceShard {
     auto d_node_sum_gradients = node_sum_gradients_d.data();
     auto d_prediction_cache = prediction_cache.data();
 
-    dh::launch_n(device_idx, prediction_cache.size(),
-                 [=] __device__(int local_idx) {
-                   int pos = d_position[local_idx];
-                   bst_float weight = CalcWeight(param_d, d_node_sum_gradients[pos]);
-                   d_prediction_cache[d_ridx[local_idx]] +=
-                       weight * param_d.learning_rate;
-                 });
+    dh::launch_n(
+        device_idx, prediction_cache.size(), [=] __device__(int local_idx) {
+          int pos = d_position[local_idx];
+          bst_float weight = CalcWeight(param_d, d_node_sum_gradients[pos]);
+          d_prediction_cache[d_ridx[local_idx]] +=
+              weight * param_d.learning_rate;
+        });
 
-    dh::safe_cuda(cudaMemcpy
-        (&out_preds_d[row_begin_idx], prediction_cache.data(),
-         prediction_cache.size() * sizeof(bst_float),
-         cudaMemcpyDefault));
+    dh::safe_cuda(cudaMemcpy(
+        &out_preds_d[row_begin_idx], prediction_cache.data(),
+        prediction_cache.size() * sizeof(bst_float), cudaMemcpyDefault));
   }
 };
 
@@ -748,11 +748,10 @@ class GPUHistMaker : public TreeUpdater {
 #pragma omp parallel
     {
       auto cpu_thread_id = omp_get_thread_num();
-      dh::safe_cuda(cudaSetDevice(shards[cpu_thread_id]->device_idx));
-      tmp_sums[cpu_thread_id] =
-          thrust::reduce(thrust::cuda::par(shards[cpu_thread_id]->temp_memory),
-                         shards[cpu_thread_id]->gpair.tbegin(),
-                         shards[cpu_thread_id]->gpair.tend());
+      auto& shard = shards[cpu_thread_id];
+      dh::safe_cuda(cudaSetDevice(shard->device_idx));
+      tmp_sums[cpu_thread_id] = dh::sumReduction(
+          shard->temp_memory, shard->gpair.data(), shard->gpair.size());
     }
     auto sum_gradient =
         std::accumulate(tmp_sums.begin(), tmp_sums.end(), bst_gpair_precise());
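The hunk above is the functional core of the change: each GPU shard now computes its partial gradient-pair sum with the CUB-backed dh::sumReduction (raw pointer plus length instead of Thrust iterators), and the host folds the partials together with std::accumulate. This relies on cub::DeviceReduce::Sum working for a user-defined pair type whose operator+ is callable in device code; a small standalone sketch of that property follows (GradPair is an illustrative stand-in, not xgboost's bst_gpair):

// gradpair_reduce_sketch.cu -- CUB sum over a user-defined type (illustrative)
#include <cub/cub.cuh>
#include <cstdio>

struct GradPair {
  float grad, hess;
  __host__ __device__ GradPair operator+(const GradPair &b) const {
    return GradPair{grad + b.grad, hess + b.hess};
  }
};

int main() {
  const int n = 4;
  GradPair h_in[n] = {{1, 2}, {3, 4}, {5, 6}, {7, 8}};
  GradPair *d_in = nullptr, *d_out = nullptr;
  cudaMalloc(&d_in, n * sizeof(GradPair));
  cudaMalloc(&d_out, sizeof(GradPair));
  cudaMemcpy(d_in, h_in, n * sizeof(GradPair), cudaMemcpyHostToDevice);

  void *d_temp = nullptr;
  size_t temp_bytes = 0;
  cub::DeviceReduce::Sum(d_temp, temp_bytes, d_in, d_out, n);  // size query
  cudaMalloc(&d_temp, temp_bytes);
  cub::DeviceReduce::Sum(d_temp, temp_bytes, d_in, d_out, n);  // reduction

  GradPair sum{};
  cudaMemcpy(&sum, d_out, sizeof(GradPair), cudaMemcpyDeviceToHost);
  printf("grad = %g, hess = %g\n", sum.grad, sum.hess);  // 16, 20
  cudaFree(d_temp);
  cudaFree(d_out);
  cudaFree(d_in);
  return 0;
}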
@@ -909,15 +908,15 @@ class GPUHistMaker : public TreeUpdater {
     omp_set_num_threads(nthread);
   }
 
-  bool UpdatePredictionCache
-  (const DMatrix* data, HostDeviceVector<bst_float>* p_out_preds) override {
+  bool UpdatePredictionCache(
+      const DMatrix* data, HostDeviceVector<bst_float>* p_out_preds) override {
     monitor.Start("UpdatePredictionCache", dList);
     if (shards.empty() || p_last_fmat_ == nullptr || p_last_fmat_ != data)
       return false;
 
-    bst_float *out_preds_d = p_out_preds->ptr_d(param.gpu_id);
+    bst_float* out_preds_d = p_out_preds->ptr_d(param.gpu_id);
 
-#pragma omp parallel for schedule(static, 1)
+#pragma omp parallel for schedule(static, 1)
     for (int shard = 0; shard < shards.size(); ++shard) {
       shards[shard]->UpdatePredictionCache(out_preds_d);
     }
 
@@ -5,8 +5,8 @@
 #include <thrust/device_vector.h>
 #include <xgboost/base.h>
 #include "../../../src/common/device_helpers.cuh"
-#include "gtest/gtest.h"
 #include "../../../src/common/timer.h"
+#include "gtest/gtest.h"
 
 void CreateTestData(xgboost::bst_uint num_rows, int max_row_size,
                     thrust::host_vector<int> *row_ptr,
@@ -38,7 +38,8 @@ void SpeedTest() {
 
   xgboost::common::Timer t;
   dh::TransformLbs(
-      0, &temp_memory, h_rows.size(), dh::raw(row_ptr), row_ptr.size() - 1, false,
+      0, &temp_memory, h_rows.size(), dh::raw(row_ptr), row_ptr.size() - 1,
+      false,
       [=] __device__(size_t idx, size_t ridx) { d_output_row[idx] = ridx; });
 
   dh::safe_cuda(cudaDeviceSynchronize());
@@ -76,4 +77,12 @@ void TestLbs() {
     }
   }
 }
 
 TEST(cub_lbs, Test) { TestLbs(); }
+
+TEST(sumReduce, Test) {
+  thrust::device_vector<float> data(100, 1.0f);
+  dh::CubMemory temp;
+  auto sum = dh::sumReduction(temp, dh::raw(data), data.size());
+  ASSERT_NEAR(sum, 100.0f, 1e-5);
+}
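Once the C++ tests are built, the new reduction test above can be run in isolation with GoogleTest's standard filter flag; the binary name here is illustrative rather than taken from this diff:

  ./testxgboost --gtest_filter=sumReduce.Test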
@@ -39,7 +39,7 @@ TEST(Metric, AUCPR) {
               0.5f, 0.001f);
   EXPECT_NEAR(
       GetMetricEval(metric,
-                    {0.4f, 0.2f, 0.9f, 0.1f, 0.2f, 0.4f, 0.1f, 0.1f, 0.2f, 0.1},
+                    {0.4f, 0.2f, 0.9f, 0.1f, 0.2f, 0.4f, 0.1f, 0.1f, 0.2f, 0.1f},
                     {0, 0, 0, 0, 0, 1, 0, 0, 1, 1}),
               0.2908445f, 0.001f);
   EXPECT_NEAR(GetMetricEval(