enable rocm, fix row_partitioner.cuh
This commit is contained in:
parent
0fc1f640a9
commit
270c7b4802
@ -116,7 +116,13 @@ template <typename RowIndexT, typename OpT, typename OpDataT>
|
|||||||
void SortPositionBatch(common::Span<const PerNodeData<OpDataT>> d_batch_info,
|
void SortPositionBatch(common::Span<const PerNodeData<OpDataT>> d_batch_info,
|
||||||
common::Span<RowIndexT> ridx, common::Span<RowIndexT> ridx_tmp,
|
common::Span<RowIndexT> ridx, common::Span<RowIndexT> ridx_tmp,
|
||||||
common::Span<bst_uint> d_counts, std::size_t total_rows, OpT op,
|
common::Span<bst_uint> d_counts, std::size_t total_rows, OpT op,
|
||||||
dh::device_vector<int8_t>* tmp, cudaStream_t stream) {
|
dh::device_vector<int8_t>* tmp,
|
||||||
|
#if defined(XGBOOST_USE_HIP)
|
||||||
|
hipStream_t stream
|
||||||
|
#else
|
||||||
|
cudaStream_t stream
|
||||||
|
#endif
|
||||||
|
) {
|
||||||
dh::LDGIterator<PerNodeData<OpDataT>> batch_info_itr(d_batch_info.data());
|
dh::LDGIterator<PerNodeData<OpDataT>> batch_info_itr(d_batch_info.data());
|
||||||
WriteResultsFunctor<OpDataT> write_results{batch_info_itr, ridx.data(), ridx_tmp.data(),
|
WriteResultsFunctor<OpDataT> write_results{batch_info_itr, ridx.data(), ridx_tmp.data(),
|
||||||
d_counts.data()};
|
d_counts.data()};
|
||||||
@ -221,7 +227,12 @@ class RowPartitioner {
|
|||||||
dh::device_vector<int8_t> tmp_;
|
dh::device_vector<int8_t> tmp_;
|
||||||
dh::PinnedMemory pinned_;
|
dh::PinnedMemory pinned_;
|
||||||
dh::PinnedMemory pinned2_;
|
dh::PinnedMemory pinned2_;
|
||||||
|
|
||||||
|
#if defined(XGBOOST_USE_HIP)
|
||||||
|
hipStream_t stream_;
|
||||||
|
#else
|
||||||
cudaStream_t stream_;
|
cudaStream_t stream_;
|
||||||
|
#endif
|
||||||
|
|
||||||
public:
|
public:
|
||||||
RowPartitioner(int device_idx, size_t num_rows);
|
RowPartitioner(int device_idx, size_t num_rows);
|
||||||
@ -276,9 +287,16 @@ class RowPartitioner {
|
|||||||
h_batch_info[i] = {ridx_segments_.at(nidx.at(i)).segment, op_data.at(i)};
|
h_batch_info[i] = {ridx_segments_.at(nidx.at(i)).segment, op_data.at(i)};
|
||||||
total_rows += ridx_segments_.at(nidx.at(i)).segment.Size();
|
total_rows += ridx_segments_.at(nidx.at(i)).segment.Size();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#if defined(XGBOOST_USE_HIP)
|
||||||
|
dh::safe_cuda(hipMemcpyAsync(d_batch_info.data().get(), h_batch_info.data(),
|
||||||
|
h_batch_info.size() * sizeof(PerNodeData<OpDataT>),
|
||||||
|
hipMemcpyDefault, stream_));
|
||||||
|
#else
|
||||||
dh::safe_cuda(cudaMemcpyAsync(d_batch_info.data().get(), h_batch_info.data(),
|
dh::safe_cuda(cudaMemcpyAsync(d_batch_info.data().get(), h_batch_info.data(),
|
||||||
h_batch_info.size() * sizeof(PerNodeData<OpDataT>),
|
h_batch_info.size() * sizeof(PerNodeData<OpDataT>),
|
||||||
cudaMemcpyDefault, stream_));
|
cudaMemcpyDefault, stream_));
|
||||||
|
#endif
|
||||||
|
|
||||||
// Temporary arrays
|
// Temporary arrays
|
||||||
auto h_counts = pinned_.GetSpan<bst_uint>(nidx.size(), 0);
|
auto h_counts = pinned_.GetSpan<bst_uint>(nidx.size(), 0);
|
||||||
@ -288,11 +306,22 @@ class RowPartitioner {
|
|||||||
SortPositionBatch<RowIndexT, UpdatePositionOpT, OpDataT>(
|
SortPositionBatch<RowIndexT, UpdatePositionOpT, OpDataT>(
|
||||||
dh::ToSpan(d_batch_info), dh::ToSpan(ridx_), dh::ToSpan(ridx_tmp_), dh::ToSpan(d_counts),
|
dh::ToSpan(d_batch_info), dh::ToSpan(ridx_), dh::ToSpan(ridx_tmp_), dh::ToSpan(d_counts),
|
||||||
total_rows, op, &tmp_, stream_);
|
total_rows, op, &tmp_, stream_);
|
||||||
|
|
||||||
|
#if defined(XGBOOST_USE_HIP)
|
||||||
|
dh::safe_cuda(hipMemcpyAsync(h_counts.data(), d_counts.data().get(), h_counts.size_bytes(),
|
||||||
|
hipMemcpyDefault, stream_));
|
||||||
|
#else
|
||||||
dh::safe_cuda(cudaMemcpyAsync(h_counts.data(), d_counts.data().get(), h_counts.size_bytes(),
|
dh::safe_cuda(cudaMemcpyAsync(h_counts.data(), d_counts.data().get(), h_counts.size_bytes(),
|
||||||
cudaMemcpyDefault, stream_));
|
cudaMemcpyDefault, stream_));
|
||||||
|
#endif
|
||||||
|
|
||||||
// TODO(Rory): this synchronisation hurts performance a lot
|
// TODO(Rory): this synchronisation hurts performance a lot
|
||||||
// Future optimisation should find a way to skip this
|
// Future optimisation should find a way to skip this
|
||||||
|
#if defined(XGBOOST_USE_HIP)
|
||||||
|
dh::safe_cuda(hipStreamSynchronize(stream_));
|
||||||
|
#else
|
||||||
dh::safe_cuda(cudaStreamSynchronize(stream_));
|
dh::safe_cuda(cudaStreamSynchronize(stream_));
|
||||||
|
#endif
|
||||||
|
|
||||||
// Update segments
|
// Update segments
|
||||||
for (size_t i = 0; i < nidx.size(); i++) {
|
for (size_t i = 0; i < nidx.size(); i++) {
|
||||||
@ -325,9 +354,16 @@ class RowPartitioner {
|
|||||||
template <typename FinalisePositionOpT>
|
template <typename FinalisePositionOpT>
|
||||||
void FinalisePosition(common::Span<bst_node_t> d_out_position, FinalisePositionOpT op) {
|
void FinalisePosition(common::Span<bst_node_t> d_out_position, FinalisePositionOpT op) {
|
||||||
dh::TemporaryArray<NodePositionInfo> d_node_info_storage(ridx_segments_.size());
|
dh::TemporaryArray<NodePositionInfo> d_node_info_storage(ridx_segments_.size());
|
||||||
|
|
||||||
|
#if defined(XGBOOST_USE_HIP)
|
||||||
|
dh::safe_cuda(hipMemcpyAsync(d_node_info_storage.data().get(), ridx_segments_.data(),
|
||||||
|
sizeof(NodePositionInfo) * ridx_segments_.size(),
|
||||||
|
hipMemcpyDefault, stream_));
|
||||||
|
#else
|
||||||
dh::safe_cuda(cudaMemcpyAsync(d_node_info_storage.data().get(), ridx_segments_.data(),
|
dh::safe_cuda(cudaMemcpyAsync(d_node_info_storage.data().get(), ridx_segments_.data(),
|
||||||
sizeof(NodePositionInfo) * ridx_segments_.size(),
|
sizeof(NodePositionInfo) * ridx_segments_.size(),
|
||||||
cudaMemcpyDefault, stream_));
|
cudaMemcpyDefault, stream_));
|
||||||
|
#endif
|
||||||
|
|
||||||
constexpr int kBlockSize = 512;
|
constexpr int kBlockSize = 512;
|
||||||
const int kItemsThread = 8;
|
const int kItemsThread = 8;
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user