finish row_partitioner.cu
This commit is contained in:
parent
495816f694
commit
500428cc0f
@ -7,7 +7,12 @@
|
|||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#if defined(XGBOOST_USE_CUDA)
|
||||||
#include "../../common/device_helpers.cuh"
|
#include "../../common/device_helpers.cuh"
|
||||||
|
#elif defined(XGBOOST_USE_HIP)
|
||||||
|
#include "../../common/device_helpers.hip.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
#include "row_partitioner.cuh"
|
#include "row_partitioner.cuh"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
@ -15,15 +20,31 @@ namespace tree {
|
|||||||
|
|
||||||
RowPartitioner::RowPartitioner(int device_idx, size_t num_rows)
|
RowPartitioner::RowPartitioner(int device_idx, size_t num_rows)
|
||||||
: device_idx_(device_idx), ridx_(num_rows), ridx_tmp_(num_rows) {
|
: device_idx_(device_idx), ridx_(num_rows), ridx_tmp_(num_rows) {
|
||||||
|
|
||||||
|
#if defined(XGBOOST_USE_CUDA)
|
||||||
dh::safe_cuda(cudaSetDevice(device_idx_));
|
dh::safe_cuda(cudaSetDevice(device_idx_));
|
||||||
|
#elif defined(XGBOOST_USE_HIP)
|
||||||
|
dh::safe_cuda(hipSetDevice(device_idx_));
|
||||||
|
#endif
|
||||||
|
|
||||||
ridx_segments_.emplace_back(NodePositionInfo{Segment(0, num_rows)});
|
ridx_segments_.emplace_back(NodePositionInfo{Segment(0, num_rows)});
|
||||||
thrust::sequence(thrust::device, ridx_.data(), ridx_.data() + ridx_.size());
|
thrust::sequence(thrust::device, ridx_.data(), ridx_.data() + ridx_.size());
|
||||||
|
|
||||||
|
#if defined(XGBOOST_USE_CUDA)
|
||||||
dh::safe_cuda(cudaStreamCreate(&stream_));
|
dh::safe_cuda(cudaStreamCreate(&stream_));
|
||||||
|
#elif defined(XGBOOST_USE_HIP)
|
||||||
|
dh::safe_cuda(hipStreamCreate(&stream_));
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
RowPartitioner::~RowPartitioner() {
|
RowPartitioner::~RowPartitioner() {
|
||||||
|
#if defined(XGBOOST_USE_CUDA)
|
||||||
dh::safe_cuda(cudaSetDevice(device_idx_));
|
dh::safe_cuda(cudaSetDevice(device_idx_));
|
||||||
dh::safe_cuda(cudaStreamDestroy(stream_));
|
dh::safe_cuda(cudaStreamDestroy(stream_));
|
||||||
|
#elif defined(XGBOOST_USE_HIP)
|
||||||
|
dh::safe_cuda(hipSetDevice(device_idx_));
|
||||||
|
dh::safe_cuda(hipStreamDestroy(stream_));
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
common::Span<const RowPartitioner::RowIndexT> RowPartitioner::GetRows(bst_node_t nidx) {
|
common::Span<const RowPartitioner::RowIndexT> RowPartitioner::GetRows(bst_node_t nidx) {
|
||||||
|
|||||||
@ -0,0 +1,4 @@
|
|||||||
|
|
||||||
|
#if defined(XGBOOST_USE_HIP)
|
||||||
|
#include "row_partitioner.cu"
|
||||||
|
#endif
|
||||||
Loading…
x
Reference in New Issue
Block a user