finish data.cu

This commit is contained in:
amdsc21
2023-03-10 05:00:57 +01:00
parent 713ab9e1a0
commit ccce4cf7e1
8 changed files with 59 additions and 16 deletions

View File

@@ -755,9 +755,9 @@ void MetaInfo::Validate(std::int32_t device) const {
}
}
#if !defined(XGBOOST_USE_CUDA)
#if !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_HIP)
void MetaInfo::SetInfoFromCUDA(Context const&, StringView, Json) { common::AssertGPUSupport(); }
#endif // !defined(XGBOOST_USE_CUDA)
#endif // !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_HIP)
using DMatrixThreadLocal =
dmlc::ThreadLocalStore<std::map<DMatrix const *, XGBAPIThreadLocalEntry>>;

View File

@@ -5,7 +5,13 @@
* \brief Handles setting metainfo from array interface.
*/
#include "../common/cuda_context.cuh"
#if defined(XGBOOST_USE_CUDA)
#include "../common/device_helpers.cuh"
#elif defined(XGBOOST_USE_HIP)
#include "../common/device_helpers.hip.h"
#endif
#include "../common/linalg_op.cuh"
#include "array_interface.h"
#include "device_adapter.cuh"
@@ -15,14 +21,22 @@
#include "xgboost/json.h"
#include "xgboost/logging.h"
#if defined(XGBOOST_USE_HIP)
namespace cub = hipcub;
#endif
namespace xgboost {
namespace {
auto SetDeviceToPtr(void const* ptr) {
#if defined(XGBOOST_USE_CUDA)
cudaPointerAttributes attr;
dh::safe_cuda(cudaPointerGetAttributes(&attr, ptr));
int32_t ptr_device = attr.device;
dh::safe_cuda(cudaSetDevice(ptr_device));
return ptr_device;
#elif defined(XGBOOST_USE_HIP) /* this is wrong, need to figure out */
return 0;
#endif
}
template <typename T, int32_t D>
@@ -43,8 +57,14 @@ void CopyTensorInfoImpl(CUDAContext const* ctx, Json arr_interface, linalg::Tens
std::copy(array.shape, array.shape + D, shape.data());
// set data
data->Resize(array.n);
#if defined(XGBOOST_USE_CUDA)
dh::safe_cuda(cudaMemcpyAsync(data->DevicePointer(), array.data, array.n * sizeof(T),
cudaMemcpyDefault, ctx->Stream()));
#elif defined(XGBOOST_USE_HIP)
dh::safe_cuda(hipMemcpyAsync(data->DevicePointer(), array.data, array.n * sizeof(T),
hipMemcpyDefault, ctx->Stream()));
#endif
});
return;
}
@@ -94,8 +114,15 @@ void CopyQidImpl(ArrayInterface<1> array_interface, std::vector<bst_group_t>* p_
}
});
bool non_dec = true;
#if defined(XGBOOST_USE_CUDA)
dh::safe_cuda(cudaMemcpy(&non_dec, flag.data().get(), sizeof(bool),
cudaMemcpyDeviceToHost));
#elif defined(XGBOOST_USE_HIP)
dh::safe_cuda(hipMemcpy(&non_dec, flag.data().get(), sizeof(bool),
hipMemcpyDeviceToHost));
#endif
CHECK(non_dec) << "`qid` must be sorted in increasing order along with data.";
size_t bytes = 0;
dh::caching_device_vector<uint32_t> out(array_interface.Shape(0));
@@ -113,8 +140,15 @@ void CopyQidImpl(ArrayInterface<1> array_interface, std::vector<bst_group_t>* p_
group_ptr_.clear();
group_ptr_.resize(h_num_runs_out + 1, 0);
dh::XGBCachingDeviceAllocator<char> alloc;
#if defined(XGBOOST_USE_CUDA)
thrust::inclusive_scan(thrust::cuda::par(alloc), cnt.begin(),
cnt.begin() + h_num_runs_out, cnt.begin());
#elif defined(XGBOOST_USE_HIP)
thrust::inclusive_scan(thrust::hip::par(alloc), cnt.begin(),
cnt.begin() + h_num_runs_out, cnt.begin());
#endif
thrust::copy(cnt.begin(), cnt.begin() + h_num_runs_out,
group_ptr_.begin() + 1);
}

View File

@@ -0,0 +1,4 @@
#if defined(XGBOOST_USE_HIP)
#include "data.cu"
#endif