diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh index 54cf920a7..3024b589f 100644 --- a/src/common/device_helpers.cuh +++ b/src/common/device_helpers.cuh @@ -209,7 +209,6 @@ inline void LaunchN(int device_idx, size_t n, cudaStream_t stream, L lambda) { if (n == 0) { return; } - safe_cuda(cudaSetDevice(device_idx)); const int GRID_SIZE = static_cast(xgboost::common::DivRoundUp(n, ITEMS_PER_THREAD * BLOCK_THREADS)); LaunchNKernel<<>>( // NOLINT @@ -368,6 +367,7 @@ struct XGBCachingDeviceAllocatorImpl : thrust::device_malloc_allocator { GlobalMemoryLogger().RegisterDeallocation(ptr.get(), n * sizeof(T)); GetGlobalCachingAllocator().DeviceFree(ptr.get()); } + __host__ __device__ void construct(T *) // NOLINT { @@ -391,6 +391,24 @@ using device_vector = thrust::device_vector>; // NOLI template using caching_device_vector = thrust::device_vector>; // NOLINT +// Faster to instantiate than caching_device_vector and invokes no synchronisation +// Use this where vector functionality (e.g. resize) is not required +template +class TemporaryArray { + public: + using AllocT = XGBCachingDeviceAllocator; + using value_type = T; // NOLINT + explicit TemporaryArray(size_t n) : size_(n) { ptr_ = AllocT().allocate(n); } + ~TemporaryArray() { AllocT().deallocate(ptr_, this->size()); } + + thrust::device_ptr data() { return ptr_; } // NOLINT + size_t size() { return size_; } // NOLINT + + private: + thrust::device_ptr ptr_; + size_t size_; +}; + /** * \brief A double buffer, useful for algorithms like sort. */ @@ -474,57 +492,6 @@ struct PinnedMemory { } }; -// Keep track of cub library device allocation -struct CubMemory { - void *d_temp_storage { nullptr }; - size_t temp_storage_bytes { 0 }; - - // Thrust - using value_type = char; // NOLINT - - CubMemory() = default; - - ~CubMemory() { Free(); } - - template - xgboost::common::Span GetSpan(size_t size) { - this->LazyAllocate(size * sizeof(T)); - return xgboost::common::Span(static_cast(d_temp_storage), size); - } - - void Free() { - if (this->IsAllocated()) { - XGBDeviceAllocator allocator; - allocator.deallocate(thrust::device_ptr(static_cast(d_temp_storage)), - temp_storage_bytes); - d_temp_storage = nullptr; - temp_storage_bytes = 0; - } - } - - void LazyAllocate(size_t num_bytes) { - if (num_bytes > temp_storage_bytes) { - Free(); - XGBDeviceAllocator allocator; - d_temp_storage = static_cast(allocator.allocate(num_bytes).get()); - temp_storage_bytes = num_bytes; - } - } - // Thrust - char *allocate(std::ptrdiff_t num_bytes) { // NOLINT - LazyAllocate(num_bytes); - return reinterpret_cast(d_temp_storage); - } - - // Thrust - void deallocate(char *ptr, size_t n) { // NOLINT - - // Do nothing - } - - bool IsAllocated() { return d_temp_storage != nullptr; } -}; - /* * Utility functions */ @@ -532,26 +499,24 @@ struct CubMemory { /** * @brief Helper function to perform device-wide sum-reduction, returns to the * host -* @param tmp_mem cub temporary memory info * @param in the input array to be reduced * @param nVals number of elements in the input array */ template -typename std::iterator_traits::value_type SumReduction( - dh::CubMemory* tmp_mem, T in, int nVals) { +typename std::iterator_traits::value_type SumReduction(T in, int nVals) { using ValueT = typename std::iterator_traits::value_type; size_t tmpSize {0}; ValueT *dummy_out = nullptr; dh::safe_cuda(cub::DeviceReduce::Sum(nullptr, tmpSize, in, dummy_out, nVals)); - // Allocate small extra memory for the return value - tmp_mem->LazyAllocate(tmpSize + sizeof(ValueT)); - auto ptr = reinterpret_cast(tmp_mem->d_temp_storage) + 1; + + TemporaryArray temp(tmpSize + sizeof(ValueT)); + auto ptr = reinterpret_cast(temp.data().get()) + 1; dh::safe_cuda(cub::DeviceReduce::Sum( reinterpret_cast(ptr), tmpSize, in, - reinterpret_cast(tmp_mem->d_temp_storage), + reinterpret_cast(temp.data().get()), nVals)); ValueT sum; - dh::safe_cuda(cudaMemcpy(&sum, tmp_mem->d_temp_storage, sizeof(ValueT), + dh::safe_cuda(cudaMemcpy(&sum, temp.data().get(), sizeof(ValueT), cudaMemcpyDeviceToHost)); return sum; } diff --git a/src/linear/updater_gpu_coordinate.cu b/src/linear/updater_gpu_coordinate.cu index 6337bc2ca..e7db2dc02 100644 --- a/src/linear/updater_gpu_coordinate.cu +++ b/src/linear/updater_gpu_coordinate.cu @@ -185,7 +185,7 @@ class GPUCoordinateUpdater : public LinearUpdater { // NOLINT counting, f); auto perm = thrust::make_permutation_iterator(gpair_.data(), skip); - return dh::SumReduction(&temp_, perm, num_row_); + return dh::SumReduction(perm, num_row_); } // This needs to be public because of the __device__ lambda. @@ -213,7 +213,7 @@ class GPUCoordinateUpdater : public LinearUpdater { // NOLINT }; // NOLINT thrust::transform_iterator multiply_iterator(counting, f); - return dh::SumReduction(&temp_, multiply_iterator, col_size); + return dh::SumReduction(multiply_iterator, col_size); } // This needs to be public because of the __device__ lambda. @@ -249,7 +249,6 @@ class GPUCoordinateUpdater : public LinearUpdater { // NOLINT std::vector row_ptr_; dh::device_vector data_; dh::caching_device_vector gpair_; - dh::CubMemory temp_; size_t num_row_; }; diff --git a/src/metric/elementwise_metric.cu b/src/metric/elementwise_metric.cu index 72381d5d6..df6e665d3 100644 --- a/src/metric/elementwise_metric.cu +++ b/src/metric/elementwise_metric.cu @@ -59,13 +59,6 @@ class ElementWiseMetricsReduction { #if defined(XGBOOST_USE_CUDA) - ~ElementWiseMetricsReduction() { - if (device_ >= 0) { - dh::safe_cuda(cudaSetDevice(device_)); - allocator_.Free(); - } - } - PackedReduceResult DeviceReduceMetrics( const HostDeviceVector& weights, const HostDeviceVector& labels, @@ -83,8 +76,9 @@ class ElementWiseMetricsReduction { auto d_policy = policy_; + dh::XGBCachingDeviceAllocator alloc; PackedReduceResult result = thrust::transform_reduce( - thrust::cuda::par(allocator_), + thrust::cuda::par(alloc), begin, end, [=] XGBOOST_DEVICE(size_t idx) { bst_float weight = is_null_weight ? 1.0f : s_weights[idx]; @@ -130,7 +124,6 @@ class ElementWiseMetricsReduction { EvalRow policy_; #if defined(XGBOOST_USE_CUDA) int device_{-1}; - dh::CubMemory allocator_; #endif // defined(XGBOOST_USE_CUDA) }; diff --git a/src/metric/multiclass_metric.cu b/src/metric/multiclass_metric.cu index 0b0dc709f..377a05010 100644 --- a/src/metric/multiclass_metric.cu +++ b/src/metric/multiclass_metric.cu @@ -73,13 +73,6 @@ class MultiClassMetricsReduction { #if defined(XGBOOST_USE_CUDA) - ~MultiClassMetricsReduction() { - if (device_ >= 0) { - dh::safe_cuda(cudaSetDevice(device_)); - allocator_.Free(); - } - } - PackedReduceResult DeviceReduceMetrics( const HostDeviceVector& weights, const HostDeviceVector& labels, @@ -98,8 +91,9 @@ class MultiClassMetricsReduction { auto s_label_error = label_error_.GetSpan(1); s_label_error[0] = 0; + dh::XGBCachingDeviceAllocator alloc; PackedReduceResult result = thrust::transform_reduce( - thrust::cuda::par(allocator_), + thrust::cuda::par(alloc), begin, end, [=] XGBOOST_DEVICE(size_t idx) { bst_float weight = is_null_weight ? 1.0f : s_weights[idx]; @@ -152,7 +146,6 @@ class MultiClassMetricsReduction { #if defined(XGBOOST_USE_CUDA) dh::PinnedMemory label_error_; int device_{-1}; - dh::CubMemory allocator_; #endif // defined(XGBOOST_USE_CUDA) }; diff --git a/src/tree/gpu_hist/row_partitioner.cuh b/src/tree/gpu_hist/row_partitioner.cuh index 03334efe6..7a6861674 100644 --- a/src/tree/gpu_hist/row_partitioner.cuh +++ b/src/tree/gpu_hist/row_partitioner.cuh @@ -108,7 +108,6 @@ class RowPartitioner { template void UpdatePosition(bst_node_t nidx, bst_node_t left_nidx, bst_node_t right_nidx, UpdatePositionOpT op) { - dh::safe_cuda(cudaSetDevice(device_idx_)); Segment segment = ridx_segments_.at(nidx); // rows belongs to node nidx auto d_ridx = ridx_.CurrentSpan(); auto d_position = position_.CurrentSpan(); diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index e9482b203..986e43fe5 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -2,9 +2,6 @@ * Copyright 2017-2020 XGBoost contributors */ #include -#include -#include -#include #include #include #include @@ -20,8 +17,6 @@ #include "xgboost/span.h" #include "xgboost/json.h" -#include "../common/common.h" -#include "../common/compressed_iterator.h" #include "../common/device_helpers.cuh" #include "../common/hist_util.h" #include "../common/timer.h" @@ -324,9 +319,9 @@ class DeviceHistogram { } void Reset() { - dh::safe_cuda(cudaMemsetAsync( - data_.data().get(), 0, - data_.size() * sizeof(typename decltype(data_)::value_type))); + auto d_data = data_.data().get(); + dh::LaunchN(device_id_, data_.size(), + [=] __device__(size_t idx) { d_data[idx] = 0.0f; }); nidx_map_.clear(); } bool HistogramExists(int nidx) const { @@ -348,30 +343,33 @@ class DeviceHistogram { // Number of items currently used in data const size_t used_size = nidx_map_.size() * HistogramSize(); const size_t new_used_size = used_size + HistogramSize(); - dh::safe_cuda(cudaSetDevice(device_id_)); if (data_.size() >= kStopGrowingSize) { // Recycle histogram memory if (new_used_size <= data_.size()) { // no need to remove old node, just insert the new one. nidx_map_[nidx] = used_size; // memset histogram size in bytes - dh::safe_cuda(cudaMemsetAsync(data_.data().get() + used_size, 0, - n_bins_ * sizeof(GradientSumT))); } else { std::pair old_entry = *nidx_map_.begin(); nidx_map_.erase(old_entry.first); - dh::safe_cuda(cudaMemsetAsync(data_.data().get() + old_entry.second, 0, - n_bins_ * sizeof(GradientSumT))); nidx_map_[nidx] = old_entry.second; } + // Zero recycled memory + auto d_data = data_.data().get() + nidx_map_[nidx]; + dh::LaunchN(device_id_, n_bins_ * 2, + [=] __device__(size_t idx) { d_data[idx] = 0.0f; }); } else { // Append new node histogram nidx_map_[nidx] = used_size; - size_t new_required_memory = std::max(data_.size() * 2, HistogramSize()); - if (data_.size() < new_required_memory) { + // Check there is enough memory for another histogram node + if (data_.size() < new_used_size + HistogramSize()) { + size_t new_required_memory = + std::max(data_.size() * 2, HistogramSize()); data_.resize(new_required_memory); } } + + CHECK_GE(data_.size(), nidx_map_.size() * HistogramSize()); } /** @@ -428,7 +426,6 @@ struct GPUHistMakerDevice { GradientSumT histogram_rounding; - dh::CubMemory temp_memory; dh::PinnedMemory pinned_memory; std::vector streams{}; @@ -531,15 +528,14 @@ struct GPUHistMakerDevice { std::vector EvaluateSplits( std::vector nidxs, const RegTree& tree, size_t num_columns) { - dh::safe_cuda(cudaSetDevice(device_id)); auto result_all = pinned_memory.GetSpan(nidxs.size()); // Work out cub temporary memory requirement GPUTrainingParam gpu_param(param); DeviceSplitCandidateReduceOp op(gpu_param); - dh::caching_device_vector d_result_all(nidxs.size()); - dh::caching_device_vector split_candidates_all(nidxs.size()*num_columns); + dh::TemporaryArray d_result_all(nidxs.size()); + dh::TemporaryArray split_candidates_all(nidxs.size()*num_columns); auto& streams = this->GetStreams(nidxs.size()); for (auto i = 0ull; i < nidxs.size(); i++) { @@ -582,7 +578,7 @@ struct GPUHistMakerDevice { cub_bytes, d_split_candidates.data(), d_result.data(), d_split_candidates.size(), op, DeviceSplitCandidate(), streams[i]); - dh::caching_device_vector cub_temp(cub_bytes); + dh::TemporaryArray cub_temp(cub_bytes); cub::DeviceReduce::Reduce(reinterpret_cast(cub_temp.data().get()), cub_bytes, d_split_candidates.data(), d_result.data(), d_split_candidates.size(), op, @@ -651,9 +647,8 @@ struct GPUHistMakerDevice { // instances to their final leaf. This information is used later to update the // prediction cache void FinalisePosition(RegTree const* p_tree, DMatrix* p_fmat) { - const auto d_nodes = - temp_memory.GetSpan(p_tree->GetNodes().size()); - dh::safe_cuda(cudaMemcpy(d_nodes.data(), p_tree->GetNodes().data(), + dh::TemporaryArray d_nodes(p_tree->GetNodes().size()); + dh::safe_cuda(cudaMemcpy(d_nodes.data().get(), p_tree->GetNodes().data(), d_nodes.size() * sizeof(RegTree::Node), cudaMemcpyHostToDevice)); @@ -662,10 +657,10 @@ struct GPUHistMakerDevice { row_partitioner.reset(new RowPartitioner(device_id, p_fmat->Info().num_row_)); } if (page->n_rows == p_fmat->Info().num_row_) { - FinalisePositionInPage(page, d_nodes); + FinalisePositionInPage(page, dh::ToSpan(d_nodes)); } else { for (auto& batch : p_fmat->GetBatches(batch_param)) { - FinalisePositionInPage(batch.Impl(), d_nodes); + FinalisePositionInPage(batch.Impl(), dh::ToSpan(d_nodes)); } } } diff --git a/tests/cpp/common/test_device_helpers.cu b/tests/cpp/common/test_device_helpers.cu index b10ea0c53..52436c025 100644 --- a/tests/cpp/common/test_device_helpers.cu +++ b/tests/cpp/common/test_device_helpers.cu @@ -10,8 +10,7 @@ TEST(SumReduce, Test) { thrust::device_vector data(100, 1.0f); - dh::CubMemory temp; - auto sum = dh::SumReduction(&temp, data.data().get(), data.size()); + auto sum = dh::SumReduction(data.data().get(), data.size()); ASSERT_NEAR(sum, 100.0f, 1e-5); }