Retire DVec class in favour of c++20 style span for device memory. (#4293)
This commit is contained in:
@@ -601,7 +601,7 @@ struct DeviceShard {
|
||||
int n_bins;
|
||||
int device_id;
|
||||
|
||||
dh::BulkAllocator<dh::MemoryType::kDevice> ba;
|
||||
dh::BulkAllocator ba;
|
||||
|
||||
ELLPackMatrix ellpack_matrix;
|
||||
|
||||
@@ -610,27 +610,26 @@ struct DeviceShard {
|
||||
DeviceHistogram<GradientSumT> hist;
|
||||
|
||||
/*! \brief row_ptr form HistCutMatrix. */
|
||||
dh::DVec<uint32_t> feature_segments;
|
||||
common::Span<uint32_t> feature_segments;
|
||||
/*! \brief minimum value for each feature. */
|
||||
dh::DVec<bst_float> min_fvalue;
|
||||
common::Span<bst_float> min_fvalue;
|
||||
/*! \brief Cut. */
|
||||
dh::DVec<bst_float> gidx_fvalue_map;
|
||||
common::Span<bst_float> gidx_fvalue_map;
|
||||
/*! \brief global index of histogram, which is stored in ELLPack format. */
|
||||
dh::DVec<common::CompressedByteT> gidx_buffer;
|
||||
common::Span<common::CompressedByteT> gidx_buffer;
|
||||
|
||||
/*! \brief Row indices relative to this shard, necessary for sorting rows. */
|
||||
dh::DVec2<bst_uint> ridx;
|
||||
dh::DoubleBuffer<bst_uint> ridx;
|
||||
dh::DoubleBuffer<int> position;
|
||||
/*! \brief Gradient pair for each row. */
|
||||
dh::DVec<GradientPair> gpair;
|
||||
common::Span<GradientPair> gpair;
|
||||
|
||||
dh::DVec2<int> position;
|
||||
|
||||
dh::DVec<int> monotone_constraints;
|
||||
dh::DVec<bst_float> prediction_cache;
|
||||
common::Span<int> monotone_constraints;
|
||||
common::Span<bst_float> prediction_cache;
|
||||
|
||||
/*! \brief Sum gradient for each node. */
|
||||
std::vector<GradientPair> node_sum_gradients;
|
||||
dh::DVec<GradientPair> node_sum_gradients_d;
|
||||
common::Span<GradientPair> node_sum_gradients_d;
|
||||
/*! \brief row offset in SparsePage (the input data). */
|
||||
thrust::device_vector<size_t> row_ptrs;
|
||||
/*! \brief On-device feature set, only actually used on one of the devices */
|
||||
@@ -718,7 +717,9 @@ struct DeviceShard {
|
||||
// Reset values for each update iteration
|
||||
void Reset(HostDeviceVector<GradientPair>* dh_gpair) {
|
||||
dh::safe_cuda(cudaSetDevice(device_id));
|
||||
position.CurrentDVec().Fill(0);
|
||||
thrust::fill(
|
||||
thrust::device_pointer_cast(position.Current()),
|
||||
thrust::device_pointer_cast(position.Current() + position.Size()), 0);
|
||||
std::fill(node_sum_gradients.begin(), node_sum_gradients.end(),
|
||||
GradientPair());
|
||||
if (left_counts.size() < 256) {
|
||||
@@ -727,13 +728,16 @@ struct DeviceShard {
|
||||
dh::safe_cuda(cudaMemsetAsync(left_counts.data().get(), 0,
|
||||
sizeof(int64_t) * left_counts.size()));
|
||||
}
|
||||
thrust::sequence(ridx.CurrentDVec().tbegin(), ridx.CurrentDVec().tend());
|
||||
thrust::sequence(
|
||||
thrust::device_pointer_cast(ridx.CurrentSpan().data()),
|
||||
thrust::device_pointer_cast(ridx.CurrentSpan().data() + ridx.Size()));
|
||||
|
||||
std::fill(ridx_segments.begin(), ridx_segments.end(), Segment(0, 0));
|
||||
ridx_segments.front() = Segment(0, ridx.Size());
|
||||
this->gpair.copy(dh_gpair->tcbegin(device_id),
|
||||
dh_gpair->tcend(device_id));
|
||||
SubsampleGradientPair(&gpair, param.subsample, row_begin_idx);
|
||||
dh::safe_cuda(cudaMemcpyAsync(
|
||||
gpair.data(), dh_gpair->ConstDevicePointer(device_id),
|
||||
gpair.size() * sizeof(GradientPair), cudaMemcpyHostToHost));
|
||||
SubsampleGradientPair(device_id, gpair, param.subsample, row_begin_idx);
|
||||
hist.Reset();
|
||||
}
|
||||
|
||||
@@ -788,7 +792,7 @@ struct DeviceShard {
|
||||
<<<uint32_t(d_feature_set.size()), kBlockThreads, 0, streams[i]>>>(
|
||||
hist.GetNodeHistogram(nidx), d_feature_set, node, ellpack_matrix,
|
||||
gpu_param, d_split_candidates, value_constraints[nidx],
|
||||
monotone_constraints.GetSpan());
|
||||
monotone_constraints);
|
||||
|
||||
// Reduce over features to find best feature
|
||||
auto d_result = d_result_all.subspan(i, 1);
|
||||
@@ -943,8 +947,8 @@ struct DeviceShard {
|
||||
void UpdatePredictionCache(bst_float* out_preds_d) {
|
||||
dh::safe_cuda(cudaSetDevice(device_id));
|
||||
if (!prediction_cache_initialised) {
|
||||
dh::safe_cuda(cudaMemcpyAsync(prediction_cache.Data(), out_preds_d,
|
||||
prediction_cache.Size() * sizeof(bst_float),
|
||||
dh::safe_cuda(cudaMemcpyAsync(prediction_cache.data(), out_preds_d,
|
||||
prediction_cache.size() * sizeof(bst_float),
|
||||
cudaMemcpyDefault));
|
||||
}
|
||||
prediction_cache_initialised = true;
|
||||
@@ -952,16 +956,16 @@ struct DeviceShard {
|
||||
CalcWeightTrainParam param_d(param);
|
||||
|
||||
dh::safe_cuda(
|
||||
cudaMemcpyAsync(node_sum_gradients_d.Data(), node_sum_gradients.data(),
|
||||
cudaMemcpyAsync(node_sum_gradients_d.data(), node_sum_gradients.data(),
|
||||
sizeof(GradientPair) * node_sum_gradients.size(),
|
||||
cudaMemcpyHostToDevice));
|
||||
auto d_position = position.Current();
|
||||
auto d_ridx = ridx.Current();
|
||||
auto d_node_sum_gradients = node_sum_gradients_d.Data();
|
||||
auto d_prediction_cache = prediction_cache.Data();
|
||||
auto d_node_sum_gradients = node_sum_gradients_d.data();
|
||||
auto d_prediction_cache = prediction_cache.data();
|
||||
|
||||
dh::LaunchN(
|
||||
device_id, prediction_cache.Size(), [=] __device__(int local_idx) {
|
||||
device_id, prediction_cache.size(), [=] __device__(int local_idx) {
|
||||
int pos = d_position[local_idx];
|
||||
bst_float weight = CalcWeight(param_d, d_node_sum_gradients[pos]);
|
||||
d_prediction_cache[d_ridx[local_idx]] +=
|
||||
@@ -969,8 +973,8 @@ struct DeviceShard {
|
||||
});
|
||||
|
||||
dh::safe_cuda(cudaMemcpy(
|
||||
out_preds_d, prediction_cache.Data(),
|
||||
prediction_cache.Size() * sizeof(bst_float), cudaMemcpyDefault));
|
||||
out_preds_d, prediction_cache.data(),
|
||||
prediction_cache.size() * sizeof(bst_float), cudaMemcpyDefault));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -981,7 +985,7 @@ struct SharedMemHistBuilder : public GPUHistBuilderBase<GradientSumT> {
|
||||
auto segment_begin = segment.begin;
|
||||
auto d_node_hist = shard->hist.GetNodeHistogram(nidx);
|
||||
auto d_ridx = shard->ridx.Current();
|
||||
auto d_gpair = shard->gpair.Data();
|
||||
auto d_gpair = shard->gpair.data();
|
||||
|
||||
auto n_elements = segment.Size() * shard->ellpack_matrix.row_stride;
|
||||
|
||||
@@ -1006,7 +1010,7 @@ struct GlobalMemHistBuilder : public GPUHistBuilderBase<GradientSumT> {
|
||||
Segment segment = shard->ridx_segments[nidx];
|
||||
auto d_node_hist = shard->hist.GetNodeHistogram(nidx).data();
|
||||
bst_uint* d_ridx = shard->ridx.Current();
|
||||
GradientPair* d_gpair = shard->gpair.Data();
|
||||
GradientPair* d_gpair = shard->gpair.data();
|
||||
|
||||
size_t const n_elements = segment.Size() * shard->ellpack_matrix.row_stride;
|
||||
auto d_matrix = shard->ellpack_matrix;
|
||||
@@ -1043,10 +1047,11 @@ inline void DeviceShard<GradientSumT>::InitCompressedData(
|
||||
&gidx_fvalue_map, hmat.cut.size(),
|
||||
&min_fvalue, hmat.min_val.size(),
|
||||
&monotone_constraints, param.monotone_constraints.size());
|
||||
gidx_fvalue_map = hmat.cut;
|
||||
min_fvalue = hmat.min_val;
|
||||
feature_segments = hmat.row_ptr;
|
||||
monotone_constraints = param.monotone_constraints;
|
||||
|
||||
dh::CopyVectorToDeviceSpan(gidx_fvalue_map, hmat.cut);
|
||||
dh::CopyVectorToDeviceSpan(min_fvalue, hmat.min_val);
|
||||
dh::CopyVectorToDeviceSpan(feature_segments, hmat.row_ptr);
|
||||
dh::CopyVectorToDeviceSpan(monotone_constraints, param.monotone_constraints);
|
||||
|
||||
node_sum_gradients.resize(max_nodes);
|
||||
ridx_segments.resize(max_nodes);
|
||||
@@ -1063,14 +1068,16 @@ inline void DeviceShard<GradientSumT>::InitCompressedData(
|
||||
<< "Max leaves and max depth cannot both be unconstrained for "
|
||||
"gpu_hist.";
|
||||
ba.Allocate(device_id, &gidx_buffer, compressed_size_bytes);
|
||||
gidx_buffer.Fill(0);
|
||||
thrust::fill(
|
||||
thrust::device_pointer_cast(gidx_buffer.data()),
|
||||
thrust::device_pointer_cast(gidx_buffer.data() + gidx_buffer.size()), 0);
|
||||
|
||||
this->CreateHistIndices(row_batch, row_stride, null_gidx_value);
|
||||
|
||||
ellpack_matrix.Init(
|
||||
feature_segments.GetSpan(), min_fvalue.GetSpan(),
|
||||
gidx_fvalue_map.GetSpan(), row_stride,
|
||||
common::CompressedIterator<uint32_t>(gidx_buffer.Data(), num_symbols),
|
||||
feature_segments, min_fvalue,
|
||||
gidx_fvalue_map, row_stride,
|
||||
common::CompressedIterator<uint32_t>(gidx_buffer.data(), num_symbols),
|
||||
is_dense, null_gidx_value);
|
||||
|
||||
// check if we can use shared memory for building histograms
|
||||
@@ -1121,10 +1128,10 @@ inline void DeviceShard<GradientSumT>::CreateHistIndices(
|
||||
dh::DivRoundUp(row_stride, block3.y), 1);
|
||||
CompressBinEllpackKernel<<<grid3, block3>>>
|
||||
(common::CompressedBufferWriter(num_symbols),
|
||||
gidx_buffer.Data(),
|
||||
gidx_buffer.data(),
|
||||
row_ptrs.data().get() + batch_row_begin,
|
||||
entries_d.data().get(),
|
||||
gidx_fvalue_map.Data(), feature_segments.Data(),
|
||||
gidx_fvalue_map.data(), feature_segments.data(),
|
||||
batch_row_begin, batch_nrows,
|
||||
row_ptrs[batch_row_begin],
|
||||
row_stride, null_gidx_value);
|
||||
@@ -1355,7 +1362,7 @@ class GPUHistMakerSpecialised{
|
||||
[&](int i, std::unique_ptr<DeviceShard<GradientSumT>>& shard) {
|
||||
dh::safe_cuda(cudaSetDevice(shard->device_id));
|
||||
tmp_sums[i] = dh::SumReduction(
|
||||
shard->temp_memory, shard->gpair.Data(), shard->gpair.Size());
|
||||
shard->temp_memory, shard->gpair.data(), shard->gpair.size());
|
||||
});
|
||||
|
||||
GradientPair sum_gradient =
|
||||
|
||||
Reference in New Issue
Block a user