Overload device memory allocation (#4532)

* Group source files, include headers in source files

* Overload device memory allocation
This commit is contained in:
Rory Mitchell
2019-06-10 11:35:13 +12:00
committed by GitHub
parent da21ac0cc2
commit 9683fd433e
9 changed files with 140 additions and 49 deletions

View File

@@ -383,7 +383,7 @@ class DeviceHistogram {
private:
/*! \brief Map nidx to starting index of its histogram. */
std::map<int, size_t> nidx_map_;
thrust::device_vector<typename GradientSumT::ValueT> data_;
dh::device_vector<typename GradientSumT::ValueT> data_;
int n_bins_;
int device_id_;
static constexpr size_t kNumItemsInGradientSum =
@@ -410,7 +410,7 @@ class DeviceHistogram {
return n_bins_ * kNumItemsInGradientSum;
}
/*! \brief Accessor returning a mutable reference to the member vector data_. */
thrust::device_vector<typename GradientSumT::ValueT>& Data() {
dh::device_vector<typename GradientSumT::ValueT>& Data() {
return data_;
}
@@ -667,10 +667,10 @@ struct DeviceShard {
std::vector<GradientPair> node_sum_gradients;
common::Span<GradientPair> node_sum_gradients_d;
/*! \brief row offset in SparsePage (the input data). */
thrust::device_vector<size_t> row_ptrs;
dh::device_vector<size_t> row_ptrs;
/*! \brief On-device feature set, only actually used on one of the devices */
thrust::device_vector<int> feature_set_d;
thrust::device_vector<int64_t>
dh::device_vector<int> feature_set_d;
dh::device_vector<int64_t>
left_counts; // Useful to keep a bunch of zeroed memory for sort position
/*! The row offset for this shard. */
bst_uint row_begin_idx;
@@ -1304,7 +1304,7 @@ inline void DeviceShard<GradientSumT>::CreateHistIndices(
static_cast<size_t>(n_rows));
const std::vector<Entry>& data_vec = row_batch.data.HostVector();
thrust::device_vector<Entry> entries_d(gpu_batch_nrows * row_stride);
dh::device_vector<Entry> entries_d(gpu_batch_nrows * row_stride);
size_t gpu_nbatches = dh::DivRoundUp(n_rows, gpu_batch_nrows);
for (size_t gpu_batch = 0; gpu_batch < gpu_nbatches; ++gpu_batch) {
@@ -1362,6 +1362,8 @@ class GPUHistMakerSpecialised {
monitor_.Init("updater_gpu_hist");
}
/*! \brief On destruction, calls Log() on the object returned by dh::GlobalMemoryLogger().
 *  NOTE(review): presumably this reports device memory usage recorded by the
 *  dh allocator - confirm against the dh::GlobalMemoryLogger definition. */
~GPUHistMakerSpecialised() { dh::GlobalMemoryLogger().Log(); }
void Update(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
const std::vector<RegTree*>& trees) {
monitor_.StartCuda("Update");