finish data.cu

This commit is contained in:
amdsc21 2023-03-10 05:00:57 +01:00
parent 713ab9e1a0
commit ccce4cf7e1
8 changed files with 59 additions and 16 deletions

View File

@ -4,7 +4,12 @@
#ifndef XGBOOST_COMMON_LINALG_OP_CUH_ #ifndef XGBOOST_COMMON_LINALG_OP_CUH_
#define XGBOOST_COMMON_LINALG_OP_CUH_ #define XGBOOST_COMMON_LINALG_OP_CUH_
#if defined(XGBOOST_USE_CUDA)
#include "device_helpers.cuh" #include "device_helpers.cuh"
#elif defined(XGBOOST_USE_HIP)
#include "device_helpers.hip.h"
#endif
#include "linalg_op.h" #include "linalg_op.h"
#include "xgboost/context.h" #include "xgboost/context.h"
#include "xgboost/linalg.h" #include "xgboost/linalg.h"
@ -14,13 +19,13 @@ namespace linalg {
template <typename T, int32_t D, typename Fn> template <typename T, int32_t D, typename Fn>
#if defined(XGBOOST_USE_HIP) #if defined(XGBOOST_USE_HIP)
void ElementWiseKernelDevice(linalg::TensorView<T, D> t, Fn&& fn, hipStream_t s = nullptr) void ElementWiseKernelDevice(linalg::TensorView<T, D> t, Fn&& fn, hipStream_t s = nullptr)
#else #elif defined(XGBOOST_USE_CUDA)
void ElementWiseKernelDevice(linalg::TensorView<T, D> t, Fn&& fn, cudaStream_t s = nullptr) void ElementWiseKernelDevice(linalg::TensorView<T, D> t, Fn&& fn, cudaStream_t s = nullptr)
#endif #endif
{ {
#if defined(XGBOOST_USE_HIP) #if defined(XGBOOST_USE_HIP)
dh::safe_cuda(hipSetDevice(t.DeviceIdx())); dh::safe_cuda(hipSetDevice(t.DeviceIdx()));
#else #elif defined(XGBOOST_USE_CUDA)
dh::safe_cuda(cudaSetDevice(t.DeviceIdx())); dh::safe_cuda(cudaSetDevice(t.DeviceIdx()));
#endif #endif
@ -40,7 +45,7 @@ void ElementWiseKernelDevice(linalg::TensorView<T, D> t, Fn&& fn, cudaStream_t s
template <typename T, int32_t D, typename Fn> template <typename T, int32_t D, typename Fn>
#if defined(XGBOOST_USE_HIP) #if defined(XGBOOST_USE_HIP)
void ElementWiseTransformDevice(linalg::TensorView<T, D> t, Fn&& fn, hipStream_t s = nullptr) void ElementWiseTransformDevice(linalg::TensorView<T, D> t, Fn&& fn, hipStream_t s = nullptr)
#else #elif defined(XGBOOST_USE_CUDA)
void ElementWiseTransformDevice(linalg::TensorView<T, D> t, Fn&& fn, cudaStream_t s = nullptr) void ElementWiseTransformDevice(linalg::TensorView<T, D> t, Fn&& fn, cudaStream_t s = nullptr)
#endif #endif
{ {

View File

@ -42,7 +42,7 @@ void ElementWiseKernelHost(linalg::TensorView<T, D> t, int32_t n_threads, Fn&& f
} }
} }
#if !defined(XGBOOST_USE_CUDA) #if !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_HIP)
template <typename T, int32_t D, typename Fn> template <typename T, int32_t D, typename Fn>
void ElementWiseKernelDevice(linalg::TensorView<T, D>, Fn&&, void* = nullptr) { void ElementWiseKernelDevice(linalg::TensorView<T, D>, Fn&&, void* = nullptr) {
common::AssertGPUSupport(); common::AssertGPUSupport();
@ -60,7 +60,7 @@ void ElementWiseKernel(Context const* ctx, linalg::TensorView<T, D> t, Fn&& fn)
} }
ElementWiseKernelHost(t, ctx->Threads(), fn); ElementWiseKernelHost(t, ctx->Threads(), fn);
} }
#endif // !defined(XGBOOST_USE_CUDA) #endif // !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_
template <typename T, std::int32_t kDim> template <typename T, std::int32_t kDim>
auto cbegin(TensorView<T, kDim> const& v) { // NOLINT auto cbegin(TensorView<T, kDim> const& v) { // NOLINT

View File

@ -755,9 +755,9 @@ void MetaInfo::Validate(std::int32_t device) const {
} }
} }
#if !defined(XGBOOST_USE_CUDA) #if !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_HIP)
void MetaInfo::SetInfoFromCUDA(Context const&, StringView, Json) { common::AssertGPUSupport(); } void MetaInfo::SetInfoFromCUDA(Context const&, StringView, Json) { common::AssertGPUSupport(); }
#endif // !defined(XGBOOST_USE_CUDA) #endif // !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_HIP)
using DMatrixThreadLocal = using DMatrixThreadLocal =
dmlc::ThreadLocalStore<std::map<DMatrix const *, XGBAPIThreadLocalEntry>>; dmlc::ThreadLocalStore<std::map<DMatrix const *, XGBAPIThreadLocalEntry>>;

View File

@ -5,7 +5,13 @@
* \brief Handles setting metainfo from array interface. * \brief Handles setting metainfo from array interface.
*/ */
#include "../common/cuda_context.cuh" #include "../common/cuda_context.cuh"
#if defined(XGBOOST_USE_CUDA)
#include "../common/device_helpers.cuh" #include "../common/device_helpers.cuh"
#elif defined(XGBOOST_USE_HIP)
#include "../common/device_helpers.hip.h"
#endif
#include "../common/linalg_op.cuh" #include "../common/linalg_op.cuh"
#include "array_interface.h" #include "array_interface.h"
#include "device_adapter.cuh" #include "device_adapter.cuh"
@ -15,14 +21,22 @@
#include "xgboost/json.h" #include "xgboost/json.h"
#include "xgboost/logging.h" #include "xgboost/logging.h"
#if defined(XGBOOST_USE_HIP)
namespace cub = hipcub;
#endif
namespace xgboost { namespace xgboost {
namespace { namespace {
auto SetDeviceToPtr(void const* ptr) { auto SetDeviceToPtr(void const* ptr) {
#if defined(XGBOOST_USE_CUDA)
cudaPointerAttributes attr; cudaPointerAttributes attr;
dh::safe_cuda(cudaPointerGetAttributes(&attr, ptr)); dh::safe_cuda(cudaPointerGetAttributes(&attr, ptr));
int32_t ptr_device = attr.device; int32_t ptr_device = attr.device;
dh::safe_cuda(cudaSetDevice(ptr_device)); dh::safe_cuda(cudaSetDevice(ptr_device));
return ptr_device; return ptr_device;
#elif defined(XGBOOST_USE_HIP) /* this is wrong, need to figure out */
return 0;
#endif
} }
template <typename T, int32_t D> template <typename T, int32_t D>
@ -43,8 +57,14 @@ void CopyTensorInfoImpl(CUDAContext const* ctx, Json arr_interface, linalg::Tens
std::copy(array.shape, array.shape + D, shape.data()); std::copy(array.shape, array.shape + D, shape.data());
// set data // set data
data->Resize(array.n); data->Resize(array.n);
#if defined(XGBOOST_USE_CUDA)
dh::safe_cuda(cudaMemcpyAsync(data->DevicePointer(), array.data, array.n * sizeof(T), dh::safe_cuda(cudaMemcpyAsync(data->DevicePointer(), array.data, array.n * sizeof(T),
cudaMemcpyDefault, ctx->Stream())); cudaMemcpyDefault, ctx->Stream()));
#elif defined(XGBOOST_USE_HIP)
dh::safe_cuda(hipMemcpyAsync(data->DevicePointer(), array.data, array.n * sizeof(T),
hipMemcpyDefault, ctx->Stream()));
#endif
}); });
return; return;
} }
@ -94,8 +114,15 @@ void CopyQidImpl(ArrayInterface<1> array_interface, std::vector<bst_group_t>* p_
} }
}); });
bool non_dec = true; bool non_dec = true;
#if defined(XGBOOST_USE_CUDA)
dh::safe_cuda(cudaMemcpy(&non_dec, flag.data().get(), sizeof(bool), dh::safe_cuda(cudaMemcpy(&non_dec, flag.data().get(), sizeof(bool),
cudaMemcpyDeviceToHost)); cudaMemcpyDeviceToHost));
#elif defined(XGBOOST_USE_HIP)
dh::safe_cuda(hipMemcpy(&non_dec, flag.data().get(), sizeof(bool),
hipMemcpyDeviceToHost));
#endif
CHECK(non_dec) << "`qid` must be sorted in increasing order along with data."; CHECK(non_dec) << "`qid` must be sorted in increasing order along with data.";
size_t bytes = 0; size_t bytes = 0;
dh::caching_device_vector<uint32_t> out(array_interface.Shape(0)); dh::caching_device_vector<uint32_t> out(array_interface.Shape(0));
@ -113,8 +140,15 @@ void CopyQidImpl(ArrayInterface<1> array_interface, std::vector<bst_group_t>* p_
group_ptr_.clear(); group_ptr_.clear();
group_ptr_.resize(h_num_runs_out + 1, 0); group_ptr_.resize(h_num_runs_out + 1, 0);
dh::XGBCachingDeviceAllocator<char> alloc; dh::XGBCachingDeviceAllocator<char> alloc;
#if defined(XGBOOST_USE_CUDA)
thrust::inclusive_scan(thrust::cuda::par(alloc), cnt.begin(), thrust::inclusive_scan(thrust::cuda::par(alloc), cnt.begin(),
cnt.begin() + h_num_runs_out, cnt.begin()); cnt.begin() + h_num_runs_out, cnt.begin());
#elif defined(XGBOOST_USE_HIP)
thrust::inclusive_scan(thrust::hip::par(alloc), cnt.begin(),
cnt.begin() + h_num_runs_out, cnt.begin());
#endif
thrust::copy(cnt.begin(), cnt.begin() + h_num_runs_out, thrust::copy(cnt.begin(), cnt.begin() + h_num_runs_out,
group_ptr_.begin() + 1); group_ptr_.begin() + 1);
} }

View File

@ -0,0 +1,4 @@
#if defined(XGBOOST_USE_HIP)
#include "data.cu"
#endif

View File

@ -13,6 +13,6 @@ DMLC_REGISTRY_FILE_TAG(quantile_obj);
} // namespace obj } // namespace obj
} // namespace xgboost } // namespace xgboost
#ifndef XGBOOST_USE_CUDA #if !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_HIP)
#include "quantile_obj.cu" #include "quantile_obj.cu"
#endif // !defined(XBGOOST_USE_CUDA) #endif // !defined(XBGOOST_USE_CUDA) && !defined(XGBOOST_USE_HIP)

View File

@ -19,7 +19,7 @@
#include "xgboost/objective.h" // ObjFunction #include "xgboost/objective.h" // ObjFunction
#include "xgboost/parameter.h" // XGBoostParameter #include "xgboost/parameter.h" // XGBoostParameter
#if defined(XGBOOST_USE_CUDA) #if defined(XGBOOST_USE_CUDA) || defined(XGBOOST_USE_HIP)
#include "../common/linalg_op.cuh" // ElementWiseKernel #include "../common/linalg_op.cuh" // ElementWiseKernel
#include "../common/stats.cuh" // SegmentedQuantile #include "../common/stats.cuh" // SegmentedQuantile
@ -123,7 +123,7 @@ class QuantileRegression : public ObjFunction {
} }
} }
} else { } else {
#if defined(XGBOOST_USE_CUDA) #if defined(XGBOOST_USE_CUDA) || defined(XGBOOST_USE_HIP)
alpha_.SetDevice(ctx_->gpu_id); alpha_.SetDevice(ctx_->gpu_id);
auto d_alpha = alpha_.ConstDeviceSpan(); auto d_alpha = alpha_.ConstDeviceSpan();
auto d_labels = info.labels.View(ctx_->gpu_id); auto d_labels = info.labels.View(ctx_->gpu_id);
@ -158,7 +158,7 @@ class QuantileRegression : public ObjFunction {
} }
#else #else
common::AssertGPUSupport(); common::AssertGPUSupport();
#endif // defined(XGBOOST_USE_CUDA) #endif // defined(XGBOOST_USE_CUDA) || defined(XGBOOST_USE_HIP)
} }
// For multiple quantiles, we should extend the base score to a vector instead of // For multiple quantiles, we should extend the base score to a vector instead of
@ -215,8 +215,8 @@ XGBOOST_REGISTER_OBJECTIVE(QuantileRegression, QuantileRegression::Name())
.describe("Regression with quantile loss.") .describe("Regression with quantile loss.")
.set_body([]() { return new QuantileRegression(); }); .set_body([]() { return new QuantileRegression(); });
#if defined(XGBOOST_USE_CUDA) #if defined(XGBOOST_USE_CUDA) || defined(XGBOOST_USE_HIP)
DMLC_REGISTRY_FILE_TAG(quantile_obj_gpu); DMLC_REGISTRY_FILE_TAG(quantile_obj_gpu);
#endif // defined(XGBOOST_USE_CUDA) #endif // defined(XGBOOST_USE_CUDA) || defined(XGBOOST_USE_HIP)
} // namespace obj } // namespace obj
} // namespace xgboost } // namespace xgboost

View File

@ -13,6 +13,6 @@ DMLC_REGISTRY_FILE_TAG(regression_obj);
} // namespace obj } // namespace obj
} // namespace xgboost } // namespace xgboost
#ifndef XGBOOST_USE_CUDA #if !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_HIP)
#include "regression_obj.cu" #include "regression_obj.cu"
#endif // XGBOOST_USE_CUDA #endif // XGBOOST_USE_CUDA && defined(XGBOOST_USE_HIP)