gpu_hist performance fixes (#5558)

* Remove unnecessary CUDA API calls
* Fix histogram memory growth

parent e1f22baf8c
commit d6d1035950
@@ -209,7 +209,6 @@ inline void LaunchN(int device_idx, size_t n, cudaStream_t stream, L lambda) {
   if (n == 0) {
     return;
   }
-  safe_cuda(cudaSetDevice(device_idx));
   const int GRID_SIZE =
       static_cast<int>(xgboost::common::DivRoundUp(n, ITEMS_PER_THREAD * BLOCK_THREADS));
   LaunchNKernel<<<GRID_SIZE, BLOCK_THREADS, 0, stream>>>(  // NOLINT
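The launcher no longer calls cudaSetDevice itself; callers are assumed to have selected the device already (later hunks drop the same per-call cudaSetDevice in RowPartitioner::UpdatePosition, DeviceHistogram::AllocateHistogram and EvaluateSplits). Below is a minimal, self-contained sketch of the grid-stride pattern LaunchN wraps, ending with the zero-fill use the histogram code switches to later in this commit. The constants and helper names are illustrative, not part of the patch; assigning 0.0f per element is bit-identical to memset-ing the bytes to zero for IEEE-754 floats, so the memset-to-LaunchN swaps below do not change results.

#include <cuda_runtime.h>
#include <cstddef>

constexpr int kBlockThreads   = 256;  // stands in for BLOCK_THREADS
constexpr int kItemsPerThread = 8;    // stands in for ITEMS_PER_THREAD

// Grid-stride loop: each thread handles every (gridDim.x * blockDim.x)-th item.
template <typename L>
__global__ void LaunchNKernelSketch(size_t n, L lambda) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += static_cast<size_t>(gridDim.x) * blockDim.x) {
    lambda(i);
  }
}

template <typename L>
void LaunchNSketch(size_t n, cudaStream_t stream, L lambda) {
  if (n == 0) { return; }
  // No cudaSetDevice here: the caller has already selected the device.
  int grid_size = static_cast<int>(
      (n + kBlockThreads * kItemsPerThread - 1) / (kBlockThreads * kItemsPerThread));
  LaunchNKernelSketch<<<grid_size, kBlockThreads, 0, stream>>>(n, lambda);
}

// Example use: zero a float buffer, the way DeviceHistogram::Reset now does.
void ZeroBuffer(float* d_data, size_t n, cudaStream_t stream) {
  LaunchNSketch(n, stream, [=] __device__(size_t idx) { d_data[idx] = 0.0f; });
}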
@@ -368,6 +367,7 @@ struct XGBCachingDeviceAllocatorImpl : thrust::device_malloc_allocator<T> {
     GlobalMemoryLogger().RegisterDeallocation(ptr.get(), n * sizeof(T));
     GetGlobalCachingAllocator().DeviceFree(ptr.get());
   }
+
  __host__ __device__
    void construct(T *)  // NOLINT
  {
@@ -391,6 +391,24 @@ using device_vector = thrust::device_vector<T, XGBDeviceAllocator<T>>;  // NOLINT
 template <typename T>
 using caching_device_vector = thrust::device_vector<T, XGBCachingDeviceAllocator<T>>;  // NOLINT

+// Faster to instantiate than caching_device_vector and invokes no synchronisation
+// Use this where vector functionality (e.g. resize) is not required
+template <typename T>
+class TemporaryArray {
+ public:
+  using AllocT = XGBCachingDeviceAllocator<T>;
+  using value_type = T;  // NOLINT
+  explicit TemporaryArray(size_t n) : size_(n) { ptr_ = AllocT().allocate(n); }
+  ~TemporaryArray() { AllocT().deallocate(ptr_, this->size()); }
+
+  thrust::device_ptr<T> data() { return ptr_; }  // NOLINT
+  size_t size() { return size_; }  // NOLINT
+
+ private:
+  thrust::device_ptr<T> ptr_;
+  size_t size_;
+};
+
 /**
  * \brief A double buffer, useful for algorithms like sort.
  */
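A hedged usage sketch for the new class (the function and buffer names are invented for illustration): construction performs a single caching-allocator allocation, and the destructor returns the block to the cache. Per the comment above it is faster to instantiate than caching_device_vector and invokes no synchronisation; its contents start uninitialised, since no element construction takes place.

#include <thrust/fill.h>
#include "../common/device_helpers.cuh"  // dh::TemporaryArray (include path as used elsewhere in the patch)

// Scratch buffer that lives for one algorithm invocation.
void UseScratch(size_t n) {
  dh::TemporaryArray<float> scratch(n);
  thrust::fill(scratch.data(), scratch.data() + scratch.size(), 0.0f);
  float* raw = scratch.data().get();  // raw device pointer for kernel arguments
  // ... launch kernels that read/write `raw` ...
  (void)raw;
}  // destructor hands the block back to the caching allocator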
@@ -474,57 +492,6 @@ struct PinnedMemory {
   }
 };

-// Keep track of cub library device allocation
-struct CubMemory {
-  void *d_temp_storage { nullptr };
-  size_t temp_storage_bytes { 0 };
-
-  // Thrust
-  using value_type = char;  // NOLINT
-
-  CubMemory() = default;
-
-  ~CubMemory() { Free(); }
-
-  template <typename T>
-  xgboost::common::Span<T> GetSpan(size_t size) {
-    this->LazyAllocate(size * sizeof(T));
-    return xgboost::common::Span<T>(static_cast<T*>(d_temp_storage), size);
-  }
-
-  void Free() {
-    if (this->IsAllocated()) {
-      XGBDeviceAllocator<uint8_t> allocator;
-      allocator.deallocate(thrust::device_ptr<uint8_t>(static_cast<uint8_t *>(d_temp_storage)),
-                           temp_storage_bytes);
-      d_temp_storage = nullptr;
-      temp_storage_bytes = 0;
-    }
-  }
-
-  void LazyAllocate(size_t num_bytes) {
-    if (num_bytes > temp_storage_bytes) {
-      Free();
-      XGBDeviceAllocator<uint8_t> allocator;
-      d_temp_storage = static_cast<void *>(allocator.allocate(num_bytes).get());
-      temp_storage_bytes = num_bytes;
-    }
-  }
-  // Thrust
-  char *allocate(std::ptrdiff_t num_bytes) {  // NOLINT
-    LazyAllocate(num_bytes);
-    return reinterpret_cast<char *>(d_temp_storage);
-  }
-
-  // Thrust
-  void deallocate(char *ptr, size_t n) {  // NOLINT
-
-    // Do nothing
-  }
-
-  bool IsAllocated() { return d_temp_storage != nullptr; }
-};
-
 /*
  * Utility functions
  */
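With CubMemory gone, code that needs cub scratch space sizes a TemporaryArray<char> from cub's query pass instead, as the rewritten SumReduction below and EvaluateSplits do. A hedged sketch of that two-phase pattern (the wrapper name is illustrative, not from the patch):

#include <cub/cub.cuh>
#include "../common/device_helpers.cuh"  // dh::TemporaryArray, dh::safe_cuda

// Query cub for the scratch size it needs, allocate exactly that much from the
// caching allocator, then run the actual reduction.
template <typename InputIt, typename OutputIt>
void SumWithTemporaryArray(InputIt in, OutputIt out, int n) {
  size_t bytes = 0;
  // First call with nullptr only computes the required temporary storage size.
  dh::safe_cuda(cub::DeviceReduce::Sum(nullptr, bytes, in, out, n));
  dh::TemporaryArray<char> temp(bytes);
  dh::safe_cuda(cub::DeviceReduce::Sum(temp.data().get(), bytes, in, out, n));
}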
@@ -532,26 +499,24 @@ struct CubMemory {
 /**
  * @brief Helper function to perform device-wide sum-reduction, returns to the
  * host
- * @param tmp_mem cub temporary memory info
  * @param in the input array to be reduced
  * @param nVals number of elements in the input array
  */
 template <typename T>
-typename std::iterator_traits<T>::value_type SumReduction(
-    dh::CubMemory* tmp_mem, T in, int nVals) {
+typename std::iterator_traits<T>::value_type SumReduction(T in, int nVals) {
   using ValueT = typename std::iterator_traits<T>::value_type;
   size_t tmpSize {0};
   ValueT *dummy_out = nullptr;
   dh::safe_cuda(cub::DeviceReduce::Sum(nullptr, tmpSize, in, dummy_out, nVals));
-  // Allocate small extra memory for the return value
-  tmp_mem->LazyAllocate(tmpSize + sizeof(ValueT));
-  auto ptr = reinterpret_cast<ValueT *>(tmp_mem->d_temp_storage) + 1;
+  TemporaryArray<char> temp(tmpSize + sizeof(ValueT));
+  auto ptr = reinterpret_cast<ValueT *>(temp.data().get()) + 1;
   dh::safe_cuda(cub::DeviceReduce::Sum(
       reinterpret_cast<void *>(ptr), tmpSize, in,
-      reinterpret_cast<ValueT *>(tmp_mem->d_temp_storage),
+      reinterpret_cast<ValueT *>(temp.data().get()),
       nVals));
   ValueT sum;
-  dh::safe_cuda(cudaMemcpy(&sum, tmp_mem->d_temp_storage, sizeof(ValueT),
+  dh::safe_cuda(cudaMemcpy(&sum, temp.data().get(), sizeof(ValueT),
                            cudaMemcpyDeviceToHost));
   return sum;
 }
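Call sites now pass just the iterator and the element count; the scratch buffer (sized one extra ValueT so its first slot can stage the result before the copy back to the host) is owned by the function itself. A hedged sketch of a call through a transform iterator, mirroring how the linear updater hunks below use it (the lambda and function name are illustrative):

#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/transform_iterator.h>
#include "../common/device_helpers.cuh"

// Sum f(i) for i in [0, n) without materialising the values.
float SumOfSquares(int n) {
  auto counting = thrust::make_counting_iterator(0);
  auto square = [=] __device__(int i) { return static_cast<float>(i) * static_cast<float>(i); };
  // The result type is spelled out explicitly, as the updater code does for its lambdas.
  thrust::transform_iterator<decltype(square), decltype(counting), float>
      squares(counting, square);
  return dh::SumReduction(squares, n);
}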
@@ -185,7 +185,7 @@ class GPUCoordinateUpdater : public LinearUpdater {  // NOLINT
         counting, f);
     auto perm = thrust::make_permutation_iterator(gpair_.data(), skip);

-    return dh::SumReduction(&temp_, perm, num_row_);
+    return dh::SumReduction(perm, num_row_);
   }

   // This needs to be public because of the __device__ lambda.
@@ -213,7 +213,7 @@ class GPUCoordinateUpdater : public LinearUpdater {  // NOLINT
     };  // NOLINT
     thrust::transform_iterator<decltype(f), decltype(counting), GradientPair>
         multiply_iterator(counting, f);
-    return dh::SumReduction(&temp_, multiply_iterator, col_size);
+    return dh::SumReduction(multiply_iterator, col_size);
   }

   // This needs to be public because of the __device__ lambda.
@@ -249,7 +249,6 @@ class GPUCoordinateUpdater : public LinearUpdater {  // NOLINT
   std::vector<size_t> row_ptr_;
   dh::device_vector<xgboost::Entry> data_;
   dh::caching_device_vector<GradientPair> gpair_;
-  dh::CubMemory temp_;
   size_t num_row_;
 };

@@ -59,13 +59,6 @@ class ElementWiseMetricsReduction {

 #if defined(XGBOOST_USE_CUDA)

-  ~ElementWiseMetricsReduction() {
-    if (device_ >= 0) {
-      dh::safe_cuda(cudaSetDevice(device_));
-      allocator_.Free();
-    }
-  }
-
   PackedReduceResult DeviceReduceMetrics(
       const HostDeviceVector<bst_float>& weights,
       const HostDeviceVector<bst_float>& labels,
@@ -83,8 +76,9 @@ class ElementWiseMetricsReduction {

     auto d_policy = policy_;

+    dh::XGBCachingDeviceAllocator<char> alloc;
     PackedReduceResult result = thrust::transform_reduce(
-        thrust::cuda::par(allocator_),
+        thrust::cuda::par(alloc),
         begin, end,
         [=] XGBOOST_DEVICE(size_t idx) {
           bst_float weight = is_null_weight ? 1.0f : s_weights[idx];
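Instead of keeping the CubMemory member whose cleanup required the destructor removed above, each reduction now constructs a throwaway XGBCachingDeviceAllocator and hands it to the thrust execution policy, so thrust's internal temporaries are served from the global allocation cache. A hedged, self-contained illustration of the same pattern (the function is invented for illustration):

#include <thrust/device_vector.h>
#include <thrust/reduce.h>
#include <thrust/system/cuda/execution_policy.h>
#include "../common/device_helpers.cuh"  // dh::XGBCachingDeviceAllocator

// Sum on the device; any scratch thrust needs is allocated through the caching allocator.
double SumOnDevice(thrust::device_vector<double> const& values) {
  dh::XGBCachingDeviceAllocator<char> alloc;
  return thrust::reduce(thrust::cuda::par(alloc), values.begin(), values.end(), 0.0);
}

The MultiClassMetricsReduction hunks below make the identical change.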
@@ -130,7 +124,6 @@ class ElementWiseMetricsReduction {
   EvalRow policy_;
 #if defined(XGBOOST_USE_CUDA)
   int device_{-1};
-  dh::CubMemory allocator_;
 #endif  // defined(XGBOOST_USE_CUDA)
 };

@@ -73,13 +73,6 @@ class MultiClassMetricsReduction {

 #if defined(XGBOOST_USE_CUDA)

-  ~MultiClassMetricsReduction() {
-    if (device_ >= 0) {
-      dh::safe_cuda(cudaSetDevice(device_));
-      allocator_.Free();
-    }
-  }
-
   PackedReduceResult DeviceReduceMetrics(
       const HostDeviceVector<bst_float>& weights,
       const HostDeviceVector<bst_float>& labels,
@@ -98,8 +91,9 @@ class MultiClassMetricsReduction {
     auto s_label_error = label_error_.GetSpan<int32_t>(1);
     s_label_error[0] = 0;

+    dh::XGBCachingDeviceAllocator<char> alloc;
     PackedReduceResult result = thrust::transform_reduce(
-        thrust::cuda::par(allocator_),
+        thrust::cuda::par(alloc),
         begin, end,
         [=] XGBOOST_DEVICE(size_t idx) {
           bst_float weight = is_null_weight ? 1.0f : s_weights[idx];
@@ -152,7 +146,6 @@ class MultiClassMetricsReduction {
 #if defined(XGBOOST_USE_CUDA)
   dh::PinnedMemory label_error_;
   int device_{-1};
-  dh::CubMemory allocator_;
 #endif  // defined(XGBOOST_USE_CUDA)
 };

@@ -108,7 +108,6 @@ class RowPartitioner {
   template <typename UpdatePositionOpT>
   void UpdatePosition(bst_node_t nidx, bst_node_t left_nidx,
                       bst_node_t right_nidx, UpdatePositionOpT op) {
-    dh::safe_cuda(cudaSetDevice(device_idx_));
     Segment segment = ridx_segments_.at(nidx);  //  rows belongs to node nidx
     auto d_ridx = ridx_.CurrentSpan();
     auto d_position = position_.CurrentSpan();
@@ -2,9 +2,6 @@
  * Copyright 2017-2020 XGBoost contributors
  */
 #include <thrust/copy.h>
-#include <thrust/functional.h>
-#include <thrust/iterator/counting_iterator.h>
-#include <thrust/iterator/transform_iterator.h>
 #include <thrust/reduce.h>
 #include <xgboost/tree_updater.h>
 #include <algorithm>
@@ -20,8 +17,6 @@
 #include "xgboost/span.h"
 #include "xgboost/json.h"

-#include "../common/common.h"
-#include "../common/compressed_iterator.h"
 #include "../common/device_helpers.cuh"
 #include "../common/hist_util.h"
 #include "../common/timer.h"
@@ -324,9 +319,9 @@ class DeviceHistogram {
   }

   void Reset() {
-    dh::safe_cuda(cudaMemsetAsync(
-        data_.data().get(), 0,
-        data_.size() * sizeof(typename decltype(data_)::value_type)));
+    auto d_data = data_.data().get();
+    dh::LaunchN(device_id_, data_.size(),
+                [=] __device__(size_t idx) { d_data[idx] = 0.0f; });
     nidx_map_.clear();
   }
   bool HistogramExists(int nidx) const {
@@ -348,30 +343,33 @@ class DeviceHistogram {
     // Number of items currently used in data
     const size_t used_size = nidx_map_.size() * HistogramSize();
     const size_t new_used_size = used_size + HistogramSize();
-    dh::safe_cuda(cudaSetDevice(device_id_));
     if (data_.size() >= kStopGrowingSize) {
       // Recycle histogram memory
       if (new_used_size <= data_.size()) {
         // no need to remove old node, just insert the new one.
         nidx_map_[nidx] = used_size;
         // memset histogram size in bytes
-        dh::safe_cuda(cudaMemsetAsync(data_.data().get() + used_size, 0,
-                                      n_bins_ * sizeof(GradientSumT)));
       } else {
         std::pair<int, size_t> old_entry = *nidx_map_.begin();
         nidx_map_.erase(old_entry.first);
-        dh::safe_cuda(cudaMemsetAsync(data_.data().get() + old_entry.second, 0,
-                                      n_bins_ * sizeof(GradientSumT)));
         nidx_map_[nidx] = old_entry.second;
       }
+      // Zero recycled memory
+      auto d_data = data_.data().get() + nidx_map_[nidx];
+      dh::LaunchN(device_id_, n_bins_ * 2,
+                  [=] __device__(size_t idx) { d_data[idx] = 0.0f; });
     } else {
       // Append new node histogram
       nidx_map_[nidx] = used_size;
-      size_t new_required_memory = std::max(data_.size() * 2, HistogramSize());
-      if (data_.size() < new_required_memory) {
+      // Check there is enough memory for another histogram node
+      if (data_.size() < new_used_size + HistogramSize()) {
+        size_t new_required_memory =
+            std::max(data_.size() * 2, HistogramSize());
         data_.resize(new_required_memory);
       }
     }

+    CHECK_GE(data_.size(), nidx_map_.size() * HistogramSize());
   }

   /**
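This is the histogram memory growth fix from the commit title. In the old append branch, std::max(data_.size() * 2, HistogramSize()) always exceeds a non-empty data_, so the buffer doubled on every node appended before kStopGrowingSize was reached, regardless of how much of it was actually in use: starting from an empty buffer, after eight nodes of histogram size H the capacity has grown to 128·H with only 8·H occupied. The new branch grows only when the next histogram would no longer fit (reaching 16·H in the same scenario), and the added CHECK_GE asserts the resulting invariant. A hedged host-side model of the two policies (illustrative names, not the real class):

#include <algorithm>
#include <cstddef>

// `capacity` models data_.size(), `nodes` the histograms already stored
// (nidx_map_.size()), and `hist` HistogramSize(), all in elements.
size_t GrowOld(size_t capacity, size_t /*nodes*/, size_t hist) {
  size_t required = std::max(capacity * 2, hist);
  // capacity < required holds on every call, so the buffer doubles each time once non-empty.
  return capacity < required ? required : capacity;
}

size_t GrowNew(size_t capacity, size_t nodes, size_t hist) {
  size_t new_used = (nodes + 1) * hist;  // usage including the node being appended
  if (capacity < new_used + hist) {      // grow only when another histogram would not fit
    return std::max(capacity * 2, hist);
  }
  return capacity;
}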
@@ -428,7 +426,6 @@ struct GPUHistMakerDevice {

   GradientSumT histogram_rounding;

-  dh::CubMemory temp_memory;
   dh::PinnedMemory pinned_memory;

   std::vector<cudaStream_t> streams{};
@@ -531,15 +528,14 @@ struct GPUHistMakerDevice {
   std::vector<DeviceSplitCandidate> EvaluateSplits(
       std::vector<int> nidxs, const RegTree& tree,
       size_t num_columns) {
-    dh::safe_cuda(cudaSetDevice(device_id));
     auto result_all = pinned_memory.GetSpan<DeviceSplitCandidate>(nidxs.size());

     // Work out cub temporary memory requirement
     GPUTrainingParam gpu_param(param);
     DeviceSplitCandidateReduceOp op(gpu_param);

-    dh::caching_device_vector<DeviceSplitCandidate> d_result_all(nidxs.size());
-    dh::caching_device_vector<DeviceSplitCandidate> split_candidates_all(nidxs.size()*num_columns);
+    dh::TemporaryArray<DeviceSplitCandidate> d_result_all(nidxs.size());
+    dh::TemporaryArray<DeviceSplitCandidate> split_candidates_all(nidxs.size()*num_columns);

     auto& streams = this->GetStreams(nidxs.size());
     for (auto i = 0ull; i < nidxs.size(); i++) {
@@ -582,7 +578,7 @@ struct GPUHistMakerDevice {
           cub_bytes, d_split_candidates.data(),
           d_result.data(), d_split_candidates.size(), op,
           DeviceSplitCandidate(), streams[i]);
-      dh::caching_device_vector<char> cub_temp(cub_bytes);
+      dh::TemporaryArray<char> cub_temp(cub_bytes);
       cub::DeviceReduce::Reduce(reinterpret_cast<void*>(cub_temp.data().get()),
           cub_bytes, d_split_candidates.data(),
           d_result.data(), d_split_candidates.size(), op,
@@ -651,9 +647,8 @@ struct GPUHistMakerDevice {
   // instances to their final leaf. This information is used later to update the
   // prediction cache
   void FinalisePosition(RegTree const* p_tree, DMatrix* p_fmat) {
-    const auto d_nodes =
-        temp_memory.GetSpan<RegTree::Node>(p_tree->GetNodes().size());
-    dh::safe_cuda(cudaMemcpy(d_nodes.data(), p_tree->GetNodes().data(),
+    dh::TemporaryArray<RegTree::Node> d_nodes(p_tree->GetNodes().size());
+    dh::safe_cuda(cudaMemcpy(d_nodes.data().get(), p_tree->GetNodes().data(),
                              d_nodes.size() * sizeof(RegTree::Node),
                              cudaMemcpyHostToDevice));

@@ -662,10 +657,10 @@ struct GPUHistMakerDevice {
       row_partitioner.reset(new RowPartitioner(device_id, p_fmat->Info().num_row_));
     }
     if (page->n_rows == p_fmat->Info().num_row_) {
-      FinalisePositionInPage(page, d_nodes);
+      FinalisePositionInPage(page, dh::ToSpan(d_nodes));
     } else {
       for (auto& batch : p_fmat->GetBatches<EllpackPage>(batch_param)) {
-        FinalisePositionInPage(batch.Impl(), d_nodes);
+        FinalisePositionInPage(batch.Impl(), dh::ToSpan(d_nodes));
       }
     }
   }
@@ -10,8 +10,7 @@

 TEST(SumReduce, Test) {
   thrust::device_vector<float> data(100, 1.0f);
-  dh::CubMemory temp;
-  auto sum = dh::SumReduction(&temp, data.data().get(), data.size());
+  auto sum = dh::SumReduction(data.data().get(), data.size());
   ASSERT_NEAR(sum, 100.0f, 1e-5);
 }
