Overload device memory allocation (#4532)
* Group source files, include headers in source files * Overload device memory allocation
This commit is contained in:
@@ -383,7 +383,7 @@ class DeviceHistogram {
|
||||
private:
|
||||
/*! \brief Map nidx to starting index of its histogram. */
|
||||
std::map<int, size_t> nidx_map_;
|
||||
thrust::device_vector<typename GradientSumT::ValueT> data_;
|
||||
dh::device_vector<typename GradientSumT::ValueT> data_;
|
||||
int n_bins_;
|
||||
int device_id_;
|
||||
static constexpr size_t kNumItemsInGradientSum =
|
||||
@@ -410,7 +410,7 @@ class DeviceHistogram {
|
||||
return n_bins_ * kNumItemsInGradientSum;
|
||||
}
|
||||
|
||||
thrust::device_vector<typename GradientSumT::ValueT>& Data() {
|
||||
dh::device_vector<typename GradientSumT::ValueT>& Data() {
|
||||
return data_;
|
||||
}
|
||||
|
||||
@@ -667,10 +667,10 @@ struct DeviceShard {
|
||||
std::vector<GradientPair> node_sum_gradients;
|
||||
common::Span<GradientPair> node_sum_gradients_d;
|
||||
/*! \brief row offset in SparsePage (the input data). */
|
||||
thrust::device_vector<size_t> row_ptrs;
|
||||
dh::device_vector<size_t> row_ptrs;
|
||||
/*! \brief On-device feature set, only actually used on one of the devices */
|
||||
thrust::device_vector<int> feature_set_d;
|
||||
thrust::device_vector<int64_t>
|
||||
dh::device_vector<int> feature_set_d;
|
||||
dh::device_vector<int64_t>
|
||||
left_counts; // Useful to keep a bunch of zeroed memory for sort position
|
||||
/*! The row offset for this shard. */
|
||||
bst_uint row_begin_idx;
|
||||
@@ -1304,7 +1304,7 @@ inline void DeviceShard<GradientSumT>::CreateHistIndices(
|
||||
static_cast<size_t>(n_rows));
|
||||
const std::vector<Entry>& data_vec = row_batch.data.HostVector();
|
||||
|
||||
thrust::device_vector<Entry> entries_d(gpu_batch_nrows * row_stride);
|
||||
dh::device_vector<Entry> entries_d(gpu_batch_nrows * row_stride);
|
||||
size_t gpu_nbatches = dh::DivRoundUp(n_rows, gpu_batch_nrows);
|
||||
|
||||
for (size_t gpu_batch = 0; gpu_batch < gpu_nbatches; ++gpu_batch) {
|
||||
@@ -1362,6 +1362,8 @@ class GPUHistMakerSpecialised {
|
||||
monitor_.Init("updater_gpu_hist");
|
||||
}
|
||||
|
||||
~GPUHistMakerSpecialised() { dh::GlobalMemoryLogger().Log(); }
|
||||
|
||||
void Update(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
|
||||
const std::vector<RegTree*>& trees) {
|
||||
monitor_.StartCuda("Update");
|
||||
|
||||
Reference in New Issue
Block a user