Overload device memory allocation (#4532)

* Group source files, include headers in source files

* Overload device memory allocation
Rory Mitchell, 2019-06-10 11:35:13 +12:00 (committed by GitHub)
parent da21ac0cc2
commit 9683fd433e
9 changed files with 140 additions and 49 deletions
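The pattern this commit introduces is a thin subclass of Thrust's stock allocator: every device vector allocation is routed through an XGBoost-owned allocator type, so allocations can be logged now and redirected to a different backend later without touching call sites. A minimal self-contained sketch of the same technique (the CountingAllocator name is illustrative, not part of the commit):

    #include <thrust/device_malloc_allocator.h>
    #include <thrust/device_vector.h>
    #include <cstddef>
    #include <cstdio>

    // Intercept allocate/deallocate by subclassing the allocator that
    // thrust::device_vector uses by default.
    template <class T>
    struct CountingAllocator : thrust::device_malloc_allocator<T> {
      using super_t = thrust::device_malloc_allocator<T>;
      using pointer = typename super_t::pointer;  // thrust::device_ptr<T>
      pointer allocate(std::size_t n) {
        std::printf("allocating %zu elements\n", n);
        return super_t::allocate(n);  // still plain cudaMalloc underneath
      }
      void deallocate(pointer ptr, std::size_t n) {
        std::printf("freeing %zu elements\n", n);
        super_t::deallocate(ptr, n);
      }
    };

    int main() {
      // The second template argument swaps in the custom allocator.
      thrust::device_vector<int, CountingAllocator<int>> v(1024);
      return 0;
    }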

src/CMakeLists.txt

@@ -1,4 +1,4 @@
-file(GLOB_RECURSE CPU_SOURCES *.cc)
+file(GLOB_RECURSE CPU_SOURCES *.cc *.h)
 list(REMOVE_ITEM CPU_SOURCES ${PROJECT_SOURCE_DIR}/src/cli_main.cc)
 include(CheckCXXSourceCompiles)
@@ -33,7 +33,7 @@ endif (PLUGIN_DENSE_PARSER)
 # Object library is necessary for jvm-package, which creates its own shared
 # library.
 if (USE_CUDA)
-  file(GLOB_RECURSE CUDA_SOURCES *.cu)
+  file(GLOB_RECURSE CUDA_SOURCES *.cu *.cuh)
   add_library(objxgboost OBJECT ${CPU_SOURCES} ${CUDA_SOURCES} ${PLUGINS_SOURCES})
   target_compile_definitions(objxgboost
     PRIVATE -DXGBOOST_USE_CUDA=1)
@@ -119,4 +119,8 @@ endif (USE_OPENMP)
 # for issues caused by mixing of /MD and /MT flags
 msvc_use_static_runtime()
+# This grouping organises source files nicely in visual studio
+auto_source_group("${CUDA_SOURCES}")
+auto_source_group("${CPU_SOURCES}")
 #-- End object library

src/common/device_helpers.cuh

@@ -4,6 +4,7 @@
 #pragma once
 #include <thrust/device_ptr.h>
 #include <thrust/device_vector.h>
+#include <thrust/device_malloc_allocator.h>
 #include <thrust/system/cuda/error.h>
 #include <thrust/system_error.h>
 #include <xgboost/logging.h>
@@ -49,11 +50,6 @@ inline ncclResult_t ThrowOnNcclError(ncclResult_t code, const char *file,
 }
 #endif
-template <typename T>
-T *Raw(thrust::device_vector<T> &v) {  // NOLINT
-  return raw_pointer_cast(v.data());
-}
 inline void CudaCheckPointerDevice(void* ptr) {
   cudaPointerAttributes attr;
   dh::safe_cuda(cudaPointerGetAttributes(&attr, ptr));
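With dh::Raw deleted above, call sites switch to the Thrust idiom that Raw merely wrapped; the two expressions below are equivalent (illustrative fragment, not part of the diff):

    // given thrust::device_vector<float> vec;
    float *p1 = thrust::raw_pointer_cast(vec.data());  // what dh::Raw returned
    float *p2 = vec.data().get();                      // spelling used after this commit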
@@ -225,6 +221,97 @@ inline void LaunchN(int device_idx, size_t n, L lambda) {
   LaunchN<ITEMS_PER_THREAD, BLOCK_THREADS>(device_idx, n, nullptr, lambda);
 }
+
+namespace detail {
+/** \brief Keeps track of global device memory allocations. Thread safe.*/
+class MemoryLogger {
+  // Information for a single device
+  struct DeviceStats {
+    size_t currently_allocated_bytes{0};
+    size_t peak_allocated_bytes{0};
+    size_t num_allocations{0};
+    size_t num_deallocations{0};
+    std::map<void *, size_t> device_allocations;
+    void RegisterAllocation(void *ptr, size_t n) {
+      device_allocations[ptr] = n;
+      currently_allocated_bytes += n;
+      peak_allocated_bytes =
+          std::max(peak_allocated_bytes, currently_allocated_bytes);
+      num_allocations++;
+    }
+    void RegisterDeallocation(void *ptr) {
+      num_deallocations++;
+      currently_allocated_bytes -= device_allocations[ptr];
+      device_allocations.erase(ptr);
+    }
+  };
+  std::map<int, DeviceStats> stats_;  // Map device ordinal to memory information
+  std::mutex mutex_;
+
+ public:
+  void RegisterAllocation(void *ptr, size_t n) {
+    if (!xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug))
+      return;
+    std::lock_guard<std::mutex> guard(mutex_);
+    int current_device;
+    safe_cuda(cudaGetDevice(&current_device));
+    stats_[current_device].RegisterAllocation(ptr, n);
+  }
+  void RegisterDeallocation(void *ptr) {
+    if (!xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug))
+      return;
+    std::lock_guard<std::mutex> guard(mutex_);
+    int current_device;
+    safe_cuda(cudaGetDevice(&current_device));
+    stats_[current_device].RegisterDeallocation(ptr);
+  }
+  void Log() {
+    if (!xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug))
+      return;
+    std::lock_guard<std::mutex> guard(mutex_);
+    for (const auto &kv : stats_) {
+      LOG(CONSOLE) << "======== Device " << kv.first << " Memory Allocations: "
+                   << " ========";
+      LOG(CONSOLE) << "Peak memory usage: "
+                   << kv.second.peak_allocated_bytes / 1000000 << "mb";
+      LOG(CONSOLE) << "Number of allocations: " << kv.second.num_allocations;
+    }
+  }
+};
+}  // namespace detail
+
+inline detail::MemoryLogger &GlobalMemoryLogger() {
+  static detail::MemoryLogger memory_logger;
+  return memory_logger;
+}
+
+namespace detail {
+/**
+ * \brief Default memory allocator, uses cudaMalloc/Free and logs allocations if verbose.
+ */
+template <class T>
+struct XGBDefaultDeviceAllocator : thrust::device_malloc_allocator<T> {
+  using super_t = thrust::device_malloc_allocator<T>;
+  using pointer = thrust::device_ptr<T>;
+  pointer allocate(size_t n) {
+    pointer ptr = super_t::allocate(n);
+    GlobalMemoryLogger().RegisterAllocation(ptr.get(), n);
+    return ptr;
+  }
+  void deallocate(pointer ptr, size_t n) {
+    GlobalMemoryLogger().RegisterDeallocation(ptr.get());
+    return super_t::deallocate(ptr, n);
+  }
+};
+}  // namespace detail
+
+// Declare xgboost allocator
+// Replacement of allocator with custom backend should occur here
+template <typename T>
+using XGBDeviceAllocator = detail::XGBDefaultDeviceAllocator<T>;
+
+/** \brief Specialisation of thrust device vector using custom allocator. */
+template <typename T>
+using device_vector = thrust::device_vector<T, XGBDeviceAllocator<T>>;
/**
* \brief A double buffer, useful for algorithms like sort.
@@ -335,10 +422,9 @@ class BulkAllocator {
   }

   char *AllocateDevice(int device_idx, size_t bytes) {
-    char *ptr;
     safe_cuda(cudaSetDevice(device_idx));
-    safe_cuda(cudaMalloc(&ptr, bytes));
-    return ptr;
+    XGBDeviceAllocator<char> allocator;
+    return allocator.allocate(bytes).get();
   }
template <typename T>
@@ -383,7 +469,8 @@ class BulkAllocator {
     for (size_t i = 0; i < d_ptr_.size(); i++) {
       if (!(d_ptr_[i] == nullptr)) {
         safe_cuda(cudaSetDevice(device_idx_[i]));
-        safe_cuda(cudaFree(d_ptr_[i]));
+        XGBDeviceAllocator<char> allocator;
+        allocator.deallocate(thrust::device_ptr<char>(d_ptr_[i]), size_[i]);
         d_ptr_[i] = nullptr;
       }
     }
@@ -453,14 +540,17 @@ struct CubMemory {
   void Free() {
     if (this->IsAllocated()) {
-      safe_cuda(cudaFree(d_temp_storage));
+      XGBDeviceAllocator<uint8_t> allocator;
+      allocator.deallocate(thrust::device_ptr<uint8_t>(static_cast<uint8_t *>(d_temp_storage)),
+                           temp_storage_bytes);
     }
   }

   void LazyAllocate(size_t num_bytes) {
     if (num_bytes > temp_storage_bytes) {
       Free();
-      safe_cuda(cudaMalloc(&d_temp_storage, num_bytes));
+      XGBDeviceAllocator<uint8_t> allocator;
+      d_temp_storage = static_cast<void *>(allocator.allocate(num_bytes).get());
       temp_storage_bytes = num_bytes;
     }
   }
@@ -1119,7 +1209,7 @@ ReduceT ReduceShards(std::vector<ShardT> *shards, FunctionT f) {
 template <typename T,
           typename IndexT = typename xgboost::common::Span<T>::index_type>
 xgboost::common::Span<T> ToSpan(
-    thrust::device_vector<T>& vec,
+    device_vector<T>& vec,
     IndexT offset = 0,
     IndexT size = -1) {
   size = size == -1 ? vec.size() : size;
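Call sites elsewhere in this commit opt in simply by declaring dh::device_vector instead of thrust::device_vector. A hedged usage sketch, assuming the dh namespace from this header and debug-level verbosity so the logger actually records:

    // Allocation goes through XGBDeviceAllocator and is registered
    // with the global memory logger.
    dh::device_vector<float> values(1 << 20);
    // ToSpan now accepts the aliased vector type directly.
    xgboost::common::Span<float> s = dh::ToSpan(values);
    // Prints per-device peak usage and allocation counts.
    dh::GlobalMemoryLogger().Log();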

src/common/hist_util.cu

@@ -130,18 +130,18 @@ struct GPUSketcher {
     tree::TrainParam param_;
     SketchContainer *sketch_container_;
-    thrust::device_vector<size_t> row_ptrs_;
-    thrust::device_vector<Entry> entries_;
-    thrust::device_vector<bst_float> fvalues_;
-    thrust::device_vector<bst_float> feature_weights_;
-    thrust::device_vector<bst_float> fvalues_cur_;
-    thrust::device_vector<WXQSketch::Entry> cuts_d_;
+    dh::device_vector<size_t> row_ptrs_;
+    dh::device_vector<Entry> entries_;
+    dh::device_vector<bst_float> fvalues_;
+    dh::device_vector<bst_float> feature_weights_;
+    dh::device_vector<bst_float> fvalues_cur_;
+    dh::device_vector<WXQSketch::Entry> cuts_d_;
     thrust::host_vector<WXQSketch::Entry> cuts_h_;
-    thrust::device_vector<bst_float> weights_;
-    thrust::device_vector<bst_float> weights2_;
+    dh::device_vector<bst_float> weights_;
+    dh::device_vector<bst_float> weights2_;
     std::vector<size_t> n_cuts_cur_;
-    thrust::device_vector<size_t> num_elements_;
-    thrust::device_vector<char> tmp_storage_;
+    dh::device_vector<size_t> num_elements_;
+    dh::device_vector<char> tmp_storage_;

    public:
     DeviceShard(int device, bst_uint row_begin, bst_uint row_end,

src/common/host_device_vector.cu

@@ -161,7 +161,7 @@ struct HostDeviceVectorImpl {
  private:
   int device_;
-  thrust::device_vector<T> data_;
+  dh::device_vector<T> data_;
   // cached vector size
   size_t cached_size_;
   size_t start_;

src/predictor/gpu_predictor.cu

@@ -261,15 +261,15 @@ class GPUPredictor : public xgboost::Predictor {
                      size_t tree_begin, size_t tree_end) {
     dh::safe_cuda(cudaSetDevice(device_));
     nodes_.resize(h_nodes.size());
-    dh::safe_cuda(cudaMemcpyAsync(dh::Raw(nodes_), h_nodes.data(),
+    dh::safe_cuda(cudaMemcpyAsync(nodes_.data().get(), h_nodes.data(),
                                   sizeof(DevicePredictionNode) * h_nodes.size(),
                                   cudaMemcpyHostToDevice));
     tree_segments_.resize(h_tree_segments.size());
-    dh::safe_cuda(cudaMemcpyAsync(dh::Raw(tree_segments_), h_tree_segments.data(),
+    dh::safe_cuda(cudaMemcpyAsync(tree_segments_.data().get(), h_tree_segments.data(),
                                   sizeof(size_t) * h_tree_segments.size(),
                                   cudaMemcpyHostToDevice));
     tree_group_.resize(model.tree_info.size());
-    dh::safe_cuda(cudaMemcpyAsync(dh::Raw(tree_group_), model.tree_info.data(),
+    dh::safe_cuda(cudaMemcpyAsync(tree_group_.data().get(), model.tree_info.data(),
                                   sizeof(int) * model.tree_info.size(),
                                   cudaMemcpyHostToDevice));
     this->tree_begin_ = tree_begin;
@@ -306,9 +306,9 @@ class GPUPredictor : public xgboost::Predictor {
   private:
    int device_;
-   thrust::device_vector<DevicePredictionNode> nodes_;
-   thrust::device_vector<size_t> tree_segments_;
-   thrust::device_vector<int> tree_group_;
+   dh::device_vector<DevicePredictionNode> nodes_;
+   dh::device_vector<size_t> tree_segments_;
+   dh::device_vector<int> tree_group_;
    size_t max_shared_memory_bytes_;
    size_t tree_begin_;
    size_t tree_end_;
@@ -373,7 +373,7 @@ class GPUPredictor : public xgboost::Predictor {
   }

  public:
-  GPUPredictor()
+  GPUPredictor()  // NOLINT
      : cpu_predictor_(Predictor::Create("cpu_predictor", learner_param_)) {}

   void PredictBatch(DMatrix* dmat, HostDeviceVector<bst_float>* out_preds,

src/tree/updater_gpu_hist.cu

@@ -383,7 +383,7 @@ class DeviceHistogram {
  private:
   /*! \brief Map nidx to starting index of its histogram. */
   std::map<int, size_t> nidx_map_;
-  thrust::device_vector<typename GradientSumT::ValueT> data_;
+  dh::device_vector<typename GradientSumT::ValueT> data_;
   int n_bins_;
   int device_id_;
   static constexpr size_t kNumItemsInGradientSum =
@@ -410,7 +410,7 @@
     return n_bins_ * kNumItemsInGradientSum;
   }

-  thrust::device_vector<typename GradientSumT::ValueT>& Data() {
+  dh::device_vector<typename GradientSumT::ValueT>& Data() {
     return data_;
   }
@@ -667,10 +667,10 @@ struct DeviceShard {
   std::vector<GradientPair> node_sum_gradients;
   common::Span<GradientPair> node_sum_gradients_d;
   /*! \brief row offset in SparsePage (the input data). */
-  thrust::device_vector<size_t> row_ptrs;
+  dh::device_vector<size_t> row_ptrs;
   /*! \brief On-device feature set, only actually used on one of the devices */
-  thrust::device_vector<int> feature_set_d;
-  thrust::device_vector<int64_t>
+  dh::device_vector<int> feature_set_d;
+  dh::device_vector<int64_t>
       left_counts;  // Useful to keep a bunch of zeroed memory for sort position
   /*! The row offset for this shard. */
   bst_uint row_begin_idx;
@@ -1304,7 +1304,7 @@ inline void DeviceShard<GradientSumT>::CreateHistIndices(
                                     static_cast<size_t>(n_rows));
   const std::vector<Entry>& data_vec = row_batch.data.HostVector();

-  thrust::device_vector<Entry> entries_d(gpu_batch_nrows * row_stride);
+  dh::device_vector<Entry> entries_d(gpu_batch_nrows * row_stride);
   size_t gpu_nbatches = dh::DivRoundUp(n_rows, gpu_batch_nrows);

   for (size_t gpu_batch = 0; gpu_batch < gpu_nbatches; ++gpu_batch) {
@@ -1362,6 +1362,8 @@ class GPUHistMakerSpecialised {
     monitor_.Init("updater_gpu_hist");
   }

+  ~GPUHistMakerSpecialised() { dh::GlobalMemoryLogger().Log(); }
+
   void Update(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
               const std::vector<RegTree*>& trees) {
     monitor_.StartCuda("Update");
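The destructor hook added above means the report from MemoryLogger::Log() prints once per training run, when the updater is torn down and verbosity is at debug level. Given the Log() implementation earlier in this commit, the output has roughly this shape (figures illustrative):

    ======== Device 0 Memory Allocations:  ========
    Peak memory usage: 128mb
    Number of allocations: 42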

tests/benchmark/benchmark_tree.py

@@ -22,7 +22,7 @@ def run_benchmark(args):
         if not (dtest.num_row() == args.rows * args.test_size
                 and dtrain.num_row() == args.rows * (1 - args.test_size)):
             raise ValueError("Wrong rows")
-    except xgb.core.XGBoostError:
+    except ValueError:
         print("Generating dataset: {} rows * {} columns".format(args.rows, args.columns))
         print("{}/{} test/train split".format(args.test_size, 1.0 - args.test_size))
         tmp = time.time()

tests/cpp/CMakeLists.txt

@@ -67,3 +67,6 @@ if (USE_OPENMP)
   target_compile_options(testxgboost PRIVATE $<$<COMPILE_LANGUAGE:CXX>:${OpenMP_CXX_FLAGS}>)
 endif (USE_OPENMP)
 set_output_directory(testxgboost ${PROJECT_BINARY_DIR})
+
+# This grouping organises source files nicely in visual studio
+auto_source_group("${TEST_SOURCES}")

tests/cpp/tree/test_gpu_hist.cu

@@ -338,8 +338,6 @@ TEST(GpuHist, EvaluateSplits) {
 }

 TEST(GpuHist, ApplySplit) {
-  GPUHistMakerSpecialised<GradientPairPrecise> hist_maker =
-      GPUHistMakerSpecialised<GradientPairPrecise>();
   int constexpr kNId = 0;
   int constexpr kNRows = 16;
   int constexpr kNCols = 8;
@@ -353,11 +351,9 @@
     param.monotone_constraints.emplace_back(0);
   }

-  hist_maker.shards_.resize(1);
-  hist_maker.shards_[0].reset(
-      new DeviceShard<GradientPairPrecise>(0, 0, 0, kNRows, param, kNCols));
+  std::unique_ptr<DeviceShard<GradientPairPrecise>> shard{
+      new DeviceShard<GradientPairPrecise>(0, 0, 0, kNRows, param, kNCols)};
-  auto& shard = hist_maker.shards_.at(0);
   shard->ridx_segments.resize(3);  // 3 nodes.
   shard->node_sum_gradients.resize(3);
@@ -368,8 +364,6 @@
   thrust::sequence(
       thrust::device_pointer_cast(shard->ridx.Current()),
       thrust::device_pointer_cast(shard->ridx.Current() + shard->ridx.Size()));
-  // Initialize GPUHistMaker
-  hist_maker.param_ = param;

   RegTree tree;
   DeviceSplitCandidate candidate;
@@ -382,7 +376,6 @@
   // Used to get bin_id in update position.
   common::HistCutMatrix cmat = GetHostCutMatrix();
-  hist_maker.hmat_ = cmat;

   MetaInfo info;
   info.num_row_ = kNRows;
@@ -421,7 +414,6 @@
   shard->ellpack_matrix.gidx_iter = common::CompressedIterator<uint32_t>(
       shard->gidx_buffer.data(), num_symbols);
-  hist_maker.info_ = &info;
   shard->ApplySplit(candidate_entry, &tree);
   shard->UpdatePosition(candidate_entry.nid, tree[candidate_entry.nid]);