finish simple_dmatrix.cu
This commit is contained in:
parent
f0febfbcac
commit
53244bef6f
@ -19,7 +19,12 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int32_t /*nthread
|
||||
auto device = (adapter->DeviceIdx() < 0 || adapter->NumRows() == 0) ? dh::CurrentDevice()
|
||||
: adapter->DeviceIdx();
|
||||
CHECK_GE(device, 0);
|
||||
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
dh::safe_cuda(cudaSetDevice(device));
|
||||
#elif defined(XGBOOST_USE_HIP)
|
||||
dh::safe_cuda(hipSetDevice(device));
|
||||
#endif
|
||||
|
||||
CHECK(adapter->NumRows() != kAdapterUnknownSize);
|
||||
CHECK(adapter->NumColumns() != kAdapterUnknownSize);
|
||||
|
||||
@ -9,19 +9,38 @@
|
||||
#include <thrust/scan.h>
|
||||
#include <thrust/execution_policy.h>
|
||||
#include "device_adapter.cuh"
|
||||
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
#include "../common/device_helpers.cuh"
|
||||
#elif defined(XGBOOST_USE_HIP)
|
||||
#include "../common/device_helpers.hip.h"
|
||||
#endif
|
||||
|
||||
namespace xgboost {
|
||||
namespace data {
|
||||
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
template <typename AdapterBatchT>
|
||||
struct COOToEntryOp {
|
||||
AdapterBatchT batch;
|
||||
|
||||
__device__ Entry operator()(size_t idx) {
|
||||
const auto& e = batch.GetElement(idx);
|
||||
return Entry(e.column_idx, e.value);
|
||||
}
|
||||
};
|
||||
#elif defined(XGBOOST_USE_HIP)
|
||||
template <typename AdapterBatchT>
|
||||
struct COOToEntryOp : thrust::unary_function<size_t, Entry> {
|
||||
AdapterBatchT batch;
|
||||
COOToEntryOp(AdapterBatchT batch): batch(batch) {};
|
||||
|
||||
__device__ Entry operator()(size_t idx) {
|
||||
const auto& e = batch.GetElement(idx);
|
||||
return Entry(e.column_idx, e.value);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
// Here the data is already correctly ordered and simply needs to be compacted
|
||||
// to remove missing data
|
||||
@ -44,7 +63,7 @@ void CountRowOffsets(const AdapterBatchT& batch, common::Span<bst_row_t> offset,
|
||||
|
||||
#if defined(XGBOOST_USE_HIP)
|
||||
dh::safe_cuda(hipSetDevice(device_idx));
|
||||
#else
|
||||
#elif defined(XGBOOST_USE_CUDA)
|
||||
dh::safe_cuda(cudaSetDevice(device_idx));
|
||||
#endif
|
||||
|
||||
@ -66,7 +85,7 @@ void CountRowOffsets(const AdapterBatchT& batch, common::Span<bst_row_t> offset,
|
||||
thrust::device_pointer_cast(offset.data()),
|
||||
thrust::device_pointer_cast(offset.data() + offset.size()),
|
||||
thrust::device_pointer_cast(offset.data()));
|
||||
#else
|
||||
#elif defined(XGBOOST_USE_CUDA)
|
||||
thrust::exclusive_scan(thrust::cuda::par(alloc),
|
||||
thrust::device_pointer_cast(offset.data()),
|
||||
thrust::device_pointer_cast(offset.data() + offset.size()),
|
||||
|
||||
@ -0,0 +1,4 @@
|
||||
|
||||
#if defined(XGBOOST_USE_HIP)
|
||||
#include "simple_dmatrix.cu"
|
||||
#endif
|
||||
Loading…
x
Reference in New Issue
Block a user