finish row_partitioner.cu

This commit is contained in:
amdsc21 2023-03-09 22:31:11 +01:00
parent 495816f694
commit 500428cc0f
2 changed files with 25 additions and 0 deletions

View File

@ -7,7 +7,12 @@
#include <vector>
#if defined(XGBOOST_USE_CUDA)
#include "../../common/device_helpers.cuh"
#elif defined(XGBOOST_USE_HIP)
#include "../../common/device_helpers.hip.h"
#endif
#include "row_partitioner.cuh"
namespace xgboost {
@ -15,15 +20,31 @@ namespace tree {
RowPartitioner::RowPartitioner(int device_idx, size_t num_rows)
: device_idx_(device_idx), ridx_(num_rows), ridx_tmp_(num_rows) {
#if defined(XGBOOST_USE_CUDA)
dh::safe_cuda(cudaSetDevice(device_idx_));
#elif defined(XGBOOST_USE_HIP)
dh::safe_cuda(hipSetDevice(device_idx_));
#endif
ridx_segments_.emplace_back(NodePositionInfo{Segment(0, num_rows)});
thrust::sequence(thrust::device, ridx_.data(), ridx_.data() + ridx_.size());
#if defined(XGBOOST_USE_CUDA)
dh::safe_cuda(cudaStreamCreate(&stream_));
#elif defined(XGBOOST_USE_HIP)
dh::safe_cuda(hipStreamCreate(&stream_));
#endif
}
RowPartitioner::~RowPartitioner() {
#if defined(XGBOOST_USE_CUDA)
dh::safe_cuda(cudaSetDevice(device_idx_));
dh::safe_cuda(cudaStreamDestroy(stream_));
#elif defined(XGBOOST_USE_HIP)
dh::safe_cuda(hipSetDevice(device_idx_));
dh::safe_cuda(hipStreamDestroy(stream_));
#endif
}
common::Span<const RowPartitioner::RowIndexT> RowPartitioner::GetRows(bst_node_t nidx) {

View File

@ -0,0 +1,4 @@
#if defined(XGBOOST_USE_HIP)
#include "row_partitioner.cu"
#endif