finish simple_dmatrix.cu

This commit is contained in:
amdsc21 2023-03-10 03:38:09 +01:00
parent f0febfbcac
commit 53244bef6f
3 changed files with 30 additions and 2 deletions

View File

@ -19,7 +19,12 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int32_t /*nthread
auto device = (adapter->DeviceIdx() < 0 || adapter->NumRows() == 0) ? dh::CurrentDevice() auto device = (adapter->DeviceIdx() < 0 || adapter->NumRows() == 0) ? dh::CurrentDevice()
: adapter->DeviceIdx(); : adapter->DeviceIdx();
CHECK_GE(device, 0); CHECK_GE(device, 0);
#if defined(XGBOOST_USE_CUDA)
dh::safe_cuda(cudaSetDevice(device)); dh::safe_cuda(cudaSetDevice(device));
#elif defined(XGBOOST_USE_HIP)
dh::safe_cuda(hipSetDevice(device));
#endif
CHECK(adapter->NumRows() != kAdapterUnknownSize); CHECK(adapter->NumRows() != kAdapterUnknownSize);
CHECK(adapter->NumColumns() != kAdapterUnknownSize); CHECK(adapter->NumColumns() != kAdapterUnknownSize);

View File

@ -9,19 +9,38 @@
#include <thrust/scan.h> #include <thrust/scan.h>
#include <thrust/execution_policy.h> #include <thrust/execution_policy.h>
#include "device_adapter.cuh" #include "device_adapter.cuh"
#if defined(XGBOOST_USE_CUDA)
#include "../common/device_helpers.cuh" #include "../common/device_helpers.cuh"
#elif defined(XGBOOST_USE_HIP)
#include "../common/device_helpers.hip.h"
#endif
namespace xgboost { namespace xgboost {
namespace data { namespace data {
#if defined(XGBOOST_USE_CUDA)
template <typename AdapterBatchT> template <typename AdapterBatchT>
struct COOToEntryOp { struct COOToEntryOp {
AdapterBatchT batch; AdapterBatchT batch;
__device__ Entry operator()(size_t idx) { __device__ Entry operator()(size_t idx) {
const auto& e = batch.GetElement(idx); const auto& e = batch.GetElement(idx);
return Entry(e.column_idx, e.value); return Entry(e.column_idx, e.value);
} }
}; };
#elif defined(XGBOOST_USE_HIP)
template <typename AdapterBatchT>
struct COOToEntryOp : thrust::unary_function<size_t, Entry> {
AdapterBatchT batch;
COOToEntryOp(AdapterBatchT batch): batch(batch) {};
__device__ Entry operator()(size_t idx) {
const auto& e = batch.GetElement(idx);
return Entry(e.column_idx, e.value);
}
};
#endif
// Here the data is already correctly ordered and simply needs to be compacted // Here the data is already correctly ordered and simply needs to be compacted
// to remove missing data // to remove missing data
@ -44,7 +63,7 @@ void CountRowOffsets(const AdapterBatchT& batch, common::Span<bst_row_t> offset,
#if defined(XGBOOST_USE_HIP) #if defined(XGBOOST_USE_HIP)
dh::safe_cuda(hipSetDevice(device_idx)); dh::safe_cuda(hipSetDevice(device_idx));
#else #elif defined(XGBOOST_USE_CUDA)
dh::safe_cuda(cudaSetDevice(device_idx)); dh::safe_cuda(cudaSetDevice(device_idx));
#endif #endif
@ -66,7 +85,7 @@ void CountRowOffsets(const AdapterBatchT& batch, common::Span<bst_row_t> offset,
thrust::device_pointer_cast(offset.data()), thrust::device_pointer_cast(offset.data()),
thrust::device_pointer_cast(offset.data() + offset.size()), thrust::device_pointer_cast(offset.data() + offset.size()),
thrust::device_pointer_cast(offset.data())); thrust::device_pointer_cast(offset.data()));
#else #elif defined(XGBOOST_USE_CUDA)
thrust::exclusive_scan(thrust::cuda::par(alloc), thrust::exclusive_scan(thrust::cuda::par(alloc),
thrust::device_pointer_cast(offset.data()), thrust::device_pointer_cast(offset.data()),
thrust::device_pointer_cast(offset.data() + offset.size()), thrust::device_pointer_cast(offset.data() + offset.size()),

View File

@ -0,0 +1,4 @@
#if defined(XGBOOST_USE_HIP)
#include "simple_dmatrix.cu"
#endif