From 9683fd433e7a50c73be5c41a4e4f6075dcefb56e Mon Sep 17 00:00:00 2001
From: Rory Mitchell
Date: Mon, 10 Jun 2019 11:35:13 +1200
Subject: [PATCH] Overload device memory allocation (#4532)

* Group source files, include headers in source files

* Overload device memory allocation

(Illustrative usage sketches for the new allocator, and a sample of its debug
output, appear in the notes after the diff.)
---
 src/CMakeLists.txt                |   8 ++-
 src/common/device_helpers.cuh     | 114 ++++++++++++++++++++++++++----
 src/common/hist_util.cu           |  20 +++---
 src/common/host_device_vector.cu  |   2 +-
 src/predictor/gpu_predictor.cu    |  14 ++--
 src/tree/updater_gpu_hist.cu      |  14 ++--
 tests/benchmark/benchmark_tree.py |   2 +-
 tests/cpp/CMakeLists.txt          |   3 +
 tests/cpp/tree/test_gpu_hist.cu   |  12 +---
 9 files changed, 140 insertions(+), 49 deletions(-)

diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index c053321b6..a215fcef5 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -1,4 +1,4 @@
-file(GLOB_RECURSE CPU_SOURCES *.cc)
+file(GLOB_RECURSE CPU_SOURCES *.cc *.h)
 list(REMOVE_ITEM CPU_SOURCES ${PROJECT_SOURCE_DIR}/src/cli_main.cc)
 
 include(CheckCXXSourceCompiles)
@@ -33,7 +33,7 @@ endif (PLUGIN_DENSE_PARSER)
 # Object library is necessary for jvm-package, which creates its own shared
 # library.
 if (USE_CUDA)
-  file(GLOB_RECURSE CUDA_SOURCES *.cu)
+  file(GLOB_RECURSE CUDA_SOURCES *.cu *.cuh)
 
   add_library(objxgboost OBJECT ${CPU_SOURCES} ${CUDA_SOURCES} ${PLUGINS_SOURCES})
   target_compile_definitions(objxgboost PRIVATE -DXGBOOST_USE_CUDA=1)
@@ -119,4 +119,8 @@ endif (USE_OPENMP)
 # for issues caused by mixing of /MD and /MT flags
 msvc_use_static_runtime()
 
+# This grouping organises source files nicely in Visual Studio
+auto_source_group("${CUDA_SOURCES}")
+auto_source_group("${CPU_SOURCES}")
+
 #-- End object library
diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh
index c849730dc..1484292f9 100644
--- a/src/common/device_helpers.cuh
+++ b/src/common/device_helpers.cuh
@@ -4,6 +4,7 @@
 #pragma once
 #include <thrust/device_ptr.h>
 #include <thrust/device_vector.h>
+#include <thrust/device_malloc_allocator.h>
 #include <thrust/system/cuda/error.h>
 #include <thrust/system_error.h>
 #include <xgboost/logging.h>
@@ -49,11 +50,6 @@ inline ncclResult_t ThrowOnNcclError(ncclResult_t code, const char *file,
 }
 #endif
 
-template <typename T>
-T *Raw(thrust::device_vector<T> &v) {  // NOLINT
-  return raw_pointer_cast(v.data());
-}
-
 inline void CudaCheckPointerDevice(void* ptr) {
   cudaPointerAttributes attr;
   dh::safe_cuda(cudaPointerGetAttributes(&attr, ptr));
@@ -225,6 +221,97 @@ inline void LaunchN(int device_idx, size_t n, L lambda) {
   LaunchN(device_idx, n, nullptr, lambda);
 }
 
+namespace detail {
+/** \brief Keeps track of global device memory allocations. Thread safe. */
+class MemoryLogger {
+  // Information for a single device
+  struct DeviceStats {
+    size_t currently_allocated_bytes{ 0 };
+    size_t peak_allocated_bytes{ 0 };
+    size_t num_allocations{ 0 };
+    size_t num_deallocations{ 0 };
+    std::map<void *, size_t> device_allocations;
+    void RegisterAllocation(void *ptr, size_t n) {
+      device_allocations[ptr] = n;
+      currently_allocated_bytes += n;
+      peak_allocated_bytes =
+          std::max(peak_allocated_bytes, currently_allocated_bytes);
+      num_allocations++;
+    }
+    void RegisterDeallocation(void *ptr) {
+      num_deallocations++;
+      currently_allocated_bytes -= device_allocations[ptr];
+      device_allocations.erase(ptr);
+    }
+  };
+  std::map<int, DeviceStats>
+      stats_;  // Map device ordinal to memory information
+  std::mutex mutex_;
+
+ public:
+  void RegisterAllocation(void *ptr, size_t n) {
+    if (!xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug))
+      return;
+    std::lock_guard<std::mutex> guard(mutex_);
+    int current_device;
+    safe_cuda(cudaGetDevice(&current_device));
+    stats_[current_device].RegisterAllocation(ptr, n);
+  }
+  void RegisterDeallocation(void *ptr) {
+    if (!xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug))
+      return;
+    std::lock_guard<std::mutex> guard(mutex_);
+    int current_device;
+    safe_cuda(cudaGetDevice(&current_device));
+    stats_[current_device].RegisterDeallocation(ptr);
+  }
+  void Log() {
+    if (!xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug))
+      return;
+    std::lock_guard<std::mutex> guard(mutex_);
+    for (const auto &kv : stats_) {
+      LOG(CONSOLE) << "======== Device " << kv.first << " Memory Allocations: "
+                   << " ========";
+      LOG(CONSOLE) << "Peak memory usage: "
+                   << kv.second.peak_allocated_bytes / 1000000 << "mb";
+      LOG(CONSOLE) << "Number of allocations: " << kv.second.num_allocations;
+    }
+  }
+};
+}  // namespace detail
+
+inline detail::MemoryLogger &GlobalMemoryLogger() {
+  static detail::MemoryLogger memory_logger;
+  return memory_logger;
+}
+
+namespace detail {
+/**
+ * \brief Default memory allocator, uses cudaMalloc/Free and logs allocations if verbose.
+ */
+template <class T>
+struct XGBDefaultDeviceAllocator : thrust::device_malloc_allocator<T> {
+  using super_t = thrust::device_malloc_allocator<T>;
+  using pointer = thrust::device_ptr<T>;
+  pointer allocate(size_t n) {
+    pointer ptr = super_t::allocate(n);
+    GlobalMemoryLogger().RegisterAllocation(ptr.get(), n);
+    return ptr;
+  }
+  void deallocate(pointer ptr, size_t n) {
+    GlobalMemoryLogger().RegisterDeallocation(ptr.get());
+    return super_t::deallocate(ptr, n);
+  }
+};
+}  // namespace detail
+
+// Declare xgboost allocator
+// Replacement of allocator with custom backend should occur here
+template <typename T>
+using XGBDeviceAllocator = detail::XGBDefaultDeviceAllocator<T>;
+/** \brief Specialisation of thrust device vector using custom allocator. */
+template <typename T>
+using device_vector = thrust::device_vector<T, XGBDeviceAllocator<T>>;
 
 /**
  * \brief A double buffer, useful for algorithms like sort.
@@ -335,10 +422,9 @@ class BulkAllocator {
   }
 
   char *AllocateDevice(int device_idx, size_t bytes) {
-    char *ptr;
     safe_cuda(cudaSetDevice(device_idx));
-    safe_cuda(cudaMalloc(&ptr, bytes));
-    return ptr;
+    XGBDeviceAllocator<char> allocator;
+    return allocator.allocate(bytes).get();
   }
 
   template <typename T>
@@ -383,7 +469,8 @@ class BulkAllocator {
     for (size_t i = 0; i < d_ptr_.size(); i++) {
       if (!(d_ptr_[i] == nullptr)) {
         safe_cuda(cudaSetDevice(device_idx_[i]));
-        safe_cuda(cudaFree(d_ptr_[i]));
+        XGBDeviceAllocator<char> allocator;
+        allocator.deallocate(thrust::device_ptr<char>(d_ptr_[i]), size_[i]);
         d_ptr_[i] = nullptr;
       }
     }
@@ -453,14 +540,17 @@ struct CubMemory {
 
   void Free() {
     if (this->IsAllocated()) {
-      safe_cuda(cudaFree(d_temp_storage));
+      XGBDeviceAllocator<uint8_t> allocator;
+      allocator.deallocate(thrust::device_ptr<uint8_t>(static_cast<uint8_t *>(d_temp_storage)),
+                           temp_storage_bytes);
     }
   }
 
   void LazyAllocate(size_t num_bytes) {
     if (num_bytes > temp_storage_bytes) {
       Free();
-      safe_cuda(cudaMalloc(&d_temp_storage, num_bytes));
+      XGBDeviceAllocator<uint8_t> allocator;
+      d_temp_storage = static_cast<void *>(allocator.allocate(num_bytes).get());
       temp_storage_bytes = num_bytes;
     }
   }
@@ -1119,7 +1209,7 @@ ReduceT ReduceShards(std::vector<ShardT> *shards, FunctionT f) {
 
 template <typename T, typename IndexT = typename xgboost::common::Span<T>::index_type>
 xgboost::common::Span<T> ToSpan(
-    thrust::device_vector<T>& vec,
+    device_vector<T>& vec,
     IndexT offset = 0,
     IndexT size = -1) {
   size = size == -1 ? vec.size() : size;
diff --git a/src/common/hist_util.cu b/src/common/hist_util.cu
index 1ee1267aa..c7fc384d8 100644
--- a/src/common/hist_util.cu
+++ b/src/common/hist_util.cu
@@ -130,18 +130,18 @@ struct GPUSketcher {
     tree::TrainParam param_;
     SketchContainer *sketch_container_;
-    thrust::device_vector<size_t> row_ptrs_;
-    thrust::device_vector<Entry> entries_;
-    thrust::device_vector<bst_float> fvalues_;
-    thrust::device_vector<bst_float> feature_weights_;
-    thrust::device_vector<bst_float> fvalues_cur_;
-    thrust::device_vector<WXQSketch::Entry> cuts_d_;
+    dh::device_vector<size_t> row_ptrs_;
+    dh::device_vector<Entry> entries_;
+    dh::device_vector<bst_float> fvalues_;
+    dh::device_vector<bst_float> feature_weights_;
+    dh::device_vector<bst_float> fvalues_cur_;
+    dh::device_vector<WXQSketch::Entry> cuts_d_;
     thrust::host_vector<WXQSketch::Entry> cuts_h_;
-    thrust::device_vector<bst_float> weights_;
-    thrust::device_vector<bst_float> weights2_;
+    dh::device_vector<bst_float> weights_;
+    dh::device_vector<bst_float> weights2_;
     std::vector<size_t> n_cuts_cur_;
-    thrust::device_vector<size_t> num_elements_;
-    thrust::device_vector<char> tmp_storage_;
+    dh::device_vector<size_t> num_elements_;
+    dh::device_vector<char> tmp_storage_;
 
    public:
     DeviceShard(int device, bst_uint row_begin, bst_uint row_end,
diff --git a/src/common/host_device_vector.cu b/src/common/host_device_vector.cu
index 7ed173188..487ae1436 100644
--- a/src/common/host_device_vector.cu
+++ b/src/common/host_device_vector.cu
@@ -161,7 +161,7 @@ struct HostDeviceVectorImpl {
 
  private:
   int device_;
-  thrust::device_vector<T> data_;
+  dh::device_vector<T> data_;
   // cached vector size
   size_t cached_size_;
   size_t start_;
diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu
index 20d4d5e10..d830d1f7b 100644
--- a/src/predictor/gpu_predictor.cu
+++ b/src/predictor/gpu_predictor.cu
@@ -261,15 +261,15 @@
                  size_t tree_begin, size_t tree_end) {
     dh::safe_cuda(cudaSetDevice(device_));
     nodes_.resize(h_nodes.size());
-    dh::safe_cuda(cudaMemcpyAsync(dh::Raw(nodes_), h_nodes.data(),
+    dh::safe_cuda(cudaMemcpyAsync(nodes_.data().get(), h_nodes.data(),
                                   sizeof(DevicePredictionNode) * h_nodes.size(),
                                   cudaMemcpyHostToDevice));
     tree_segments_.resize(h_tree_segments.size());
-    dh::safe_cuda(cudaMemcpyAsync(dh::Raw(tree_segments_), h_tree_segments.data(),
+    dh::safe_cuda(cudaMemcpyAsync(tree_segments_.data().get(), h_tree_segments.data(),
                                   sizeof(size_t) * h_tree_segments.size(),
                                   cudaMemcpyHostToDevice));
     tree_group_.resize(model.tree_info.size());
-    dh::safe_cuda(cudaMemcpyAsync(dh::Raw(tree_group_), model.tree_info.data(),
+    dh::safe_cuda(cudaMemcpyAsync(tree_group_.data().get(), model.tree_info.data(),
                                   sizeof(int) * model.tree_info.size(),
                                   cudaMemcpyHostToDevice));
     this->tree_begin_ = tree_begin;
@@ -306,9 +306,9 @@ class GPUPredictor : public xgboost::Predictor {
 
  private:
   int device_;
-  thrust::device_vector<DevicePredictionNode> nodes_;
-  thrust::device_vector<size_t> tree_segments_;
-  thrust::device_vector<int> tree_group_;
+  dh::device_vector<DevicePredictionNode> nodes_;
+  dh::device_vector<size_t> tree_segments_;
+  dh::device_vector<int> tree_group_;
   size_t max_shared_memory_bytes_;
   size_t tree_begin_;
   size_t tree_end_;
@@ -373,7 +373,7 @@ class GPUPredictor : public xgboost::Predictor {
   }
 
  public:
-  GPUPredictor()
+  GPUPredictor()  // NOLINT
     : cpu_predictor_(Predictor::Create("cpu_predictor", learner_param_)) {}
 
   void PredictBatch(DMatrix* dmat, HostDeviceVector<bst_float>* out_preds,
diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu
index 4286a6c1c..c4884442d 100644
--- a/src/tree/updater_gpu_hist.cu
+++ b/src/tree/updater_gpu_hist.cu
@@ -383,7 +383,7 @@ class DeviceHistogram {
  private:
   /*! \brief Map nidx to starting index of its histogram. */
   std::map<int, size_t> nidx_map_;
-  thrust::device_vector<typename GradientSumT::ValueT> data_;
+  dh::device_vector<typename GradientSumT::ValueT> data_;
   int n_bins_;
   int device_id_;
   static constexpr size_t kNumItemsInGradientSum =
@@ -410,7 +410,7 @@
     return n_bins_ * kNumItemsInGradientSum;
   }
 
-  thrust::device_vector<typename GradientSumT::ValueT>& Data() {
+  dh::device_vector<typename GradientSumT::ValueT>& Data() {
     return data_;
   }
 
@@ -667,10 +667,10 @@ struct DeviceShard {
   std::vector<GradientPair> node_sum_gradients;
   common::Span<GradientPair> node_sum_gradients_d;
   /*! \brief row offset in SparsePage (the input data). */
-  thrust::device_vector<size_t> row_ptrs;
+  dh::device_vector<size_t> row_ptrs;
   /*! \brief On-device feature set, only actually used on one of the devices */
-  thrust::device_vector<int> feature_set_d;
-  thrust::device_vector<int64_t>
+  dh::device_vector<int> feature_set_d;
+  dh::device_vector<int64_t>
       left_counts;  // Useful to keep a bunch of zeroed memory for sort position
   /*! The row offset for this shard. */
   bst_uint row_begin_idx;
@@ -1304,7 +1304,7 @@ inline void DeviceShard<GradientSumT>::CreateHistIndices(
       static_cast<size_t>(n_rows));
   const std::vector<Entry>& data_vec = row_batch.data.HostVector();
 
-  thrust::device_vector<Entry> entries_d(gpu_batch_nrows * row_stride);
+  dh::device_vector<Entry> entries_d(gpu_batch_nrows * row_stride);
   size_t gpu_nbatches = dh::DivRoundUp(n_rows, gpu_batch_nrows);
 
   for (size_t gpu_batch = 0; gpu_batch < gpu_nbatches; ++gpu_batch) {
@@ -1362,6 +1362,8 @@ class GPUHistMakerSpecialised {
     monitor_.Init("updater_gpu_hist");
   }
 
+  ~GPUHistMakerSpecialised() { dh::GlobalMemoryLogger().Log(); }
+
   void Update(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
               const std::vector<RegTree*>& trees) {
     monitor_.StartCuda("Update");
diff --git a/tests/benchmark/benchmark_tree.py b/tests/benchmark/benchmark_tree.py
index b055b4ee6..446deb714 100644
--- a/tests/benchmark/benchmark_tree.py
+++ b/tests/benchmark/benchmark_tree.py
@@ -22,7 +22,7 @@ def run_benchmark(args):
         if not (dtest.num_row() == args.rows * args.test_size
                 and dtrain.num_row() == args.rows * (1 - args.test_size)):
             raise ValueError("Wrong rows")
-    except xgb.core.XGBoostError:
+    except ValueError:
         print("Generating dataset: {} rows * {} columns".format(args.rows, args.columns))
         print("{}/{} test/train split".format(args.test_size, 1.0 - args.test_size))
         tmp = time.time()
diff --git a/tests/cpp/CMakeLists.txt b/tests/cpp/CMakeLists.txt
index b357a610f..8afcd6ab9 100644
--- a/tests/cpp/CMakeLists.txt
+++ b/tests/cpp/CMakeLists.txt
@@ -67,3 +67,6 @@ if (USE_OPENMP)
   target_compile_options(testxgboost PRIVATE $<$<COMPILE_LANGUAGE:CXX>:${OpenMP_CXX_FLAGS}>)
 endif (USE_OPENMP)
 set_output_directory(testxgboost ${PROJECT_BINARY_DIR})
+
+# This grouping organises source files nicely in Visual Studio
+auto_source_group("${TEST_SOURCES}")
diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu
index 2d120ecc5..6a59859e3 100644
--- a/tests/cpp/tree/test_gpu_hist.cu
+++ b/tests/cpp/tree/test_gpu_hist.cu
@@ -338,8 +338,6 @@ TEST(GpuHist, EvaluateSplits) {
 }
 
 TEST(GpuHist, ApplySplit) {
-  GPUHistMakerSpecialised<GradientPairPrecise> hist_maker =
-      GPUHistMakerSpecialised<GradientPairPrecise>();
   int constexpr kNId = 0;
   int constexpr kNRows = 16;
   int constexpr kNCols = 8;
@@ -353,11 +351,9 @@ TEST(GpuHist, ApplySplit) {
     param.monotone_constraints.emplace_back(0);
   }
 
-  hist_maker.shards_.resize(1);
-  hist_maker.shards_[0].reset(
-      new DeviceShard<GradientPairPrecise>(0, 0, 0, kNRows, param, kNCols));
+  std::unique_ptr<DeviceShard<GradientPairPrecise>> shard{
+      new DeviceShard<GradientPairPrecise>(0, 0, 0, kNRows, param, kNCols)};
 
-  auto& shard = hist_maker.shards_.at(0);
   shard->ridx_segments.resize(3);  // 3 nodes.
   shard->node_sum_gradients.resize(3);
@@ -368,8 +364,6 @@ TEST(GpuHist, ApplySplit) {
   thrust::sequence(
       thrust::device_pointer_cast(shard->ridx.Current()),
       thrust::device_pointer_cast(shard->ridx.Current() + shard->ridx.Size()));
-  // Initialize GPUHistMaker
-  hist_maker.param_ = param;
 
   RegTree tree;
   DeviceSplitCandidate candidate;
@@ -382,7 +376,6 @@ TEST(GpuHist, ApplySplit) {
 
   // Used to get bin_id in update position.
   common::HistCutMatrix cmat = GetHostCutMatrix();
-  hist_maker.hmat_ = cmat;
 
   MetaInfo info;
   info.num_row_ = kNRows;
@@ -421,7 +414,6 @@ TEST(GpuHist, ApplySplit) {
   shard->ellpack_matrix.gidx_iter = common::CompressedIterator<uint32_t>(
       shard->gidx_buffer.data(), num_symbols);
 
-  hist_maker.info_ = &info;
   shard->ApplySplit(candidate_entry, &tree);
   shard->UpdatePosition(candidate_entry.nid, tree[candidate_entry.nid]);
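
Notes (editor's illustrative sketches; not part of the patch above):

The heart of the change is a standard Thrust allocator-override pattern: derive
from thrust::device_malloc_allocator<T>, intercept allocate()/deallocate() to
record the event, and expose the container through an alias so the backend can
be replaced at a single point. The sketch below reproduces that pattern outside
xgboost; TracingAllocator is a hypothetical name, and printf stands in for the
patch's MemoryLogger. It compiles with nvcc as an ordinary .cu file:

    #include <cstdio>
    #include <thrust/device_malloc_allocator.h>
    #include <thrust/device_vector.h>

    // Same shape as detail::XGBDefaultDeviceAllocator: forward to the default
    // backend and record each allocation/deallocation on the way through.
    template <class T>
    struct TracingAllocator : thrust::device_malloc_allocator<T> {
      using super_t = thrust::device_malloc_allocator<T>;
      using pointer = typename super_t::pointer;  // thrust::device_ptr<T>

      pointer allocate(size_t n) {
        pointer ptr = super_t::allocate(n);
        std::printf("allocate   %zu elements at %p\n", n,
                    static_cast<void *>(ptr.get()));
        return ptr;
      }
      void deallocate(pointer ptr, size_t n) {
        std::printf("deallocate %zu elements at %p\n", n,
                    static_cast<void *>(ptr.get()));
        super_t::deallocate(ptr, n);
      }
    };

    // Swapping in a different backend (e.g. a caching pool) later means
    // editing only this alias, which is what the patch's "Replacement of
    // allocator with custom backend should occur here" comment points at.
    template <typename T>
    using device_vector = thrust::device_vector<T, TracingAllocator<T>>;

    int main() {
      device_vector<int> v(1024);  // prints "allocate   1024 elements ..."
      return 0;                    // destructor prints the matching deallocate
    }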
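
Inside xgboost, MemoryLogger deliberately records nothing unless the console
logger is at debug verbosity, so the bookkeeping costs nothing in normal runs.
A minimal sketch of exercising it, assuming a .cu translation unit compiled
inside the xgboost source tree; the relative include path and the
ConsoleLogger::Configure call are assumptions about the surrounding code base
of this era, not something the patch itself provides:

    #include <xgboost/logging.h>

    #include "../../src/common/device_helpers.cuh"

    int main() {
      // Allocations are only tracked at debug verbosity.
      xgboost::ConsoleLogger::Configure({{"verbosity", "3"}});
      {
        // Goes through XGBDeviceAllocator<float>; the matching deallocation
        // is registered when the vector leaves scope.
        dh::device_vector<float> data(4 * 1024 * 1024);
      }
      dh::GlobalMemoryLogger().Log();  // per-device summary on the console
      return 0;
    }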
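
With that wiring in place, ~GPUHistMakerSpecialised() calls
dh::GlobalMemoryLogger().Log() when the updater is destroyed, so a
debug-verbosity gpu_hist run now ends with one summary block per device. From
the Log() implementation above, the output has this shape (the figures are
made up for illustration):

    ======== Device 0 Memory Allocations:  ========
    Peak memory usage: 128mb
    Number of allocations: 442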