fix array_interface.h half type

This commit is contained in:
amdsc21 2023-05-02 20:50:50 +02:00
parent 65097212b3
commit b324d51f14
3 changed files with 3 additions and 3 deletions

View File

@ -60,7 +60,7 @@ void ElementWiseKernel(Context const* ctx, linalg::TensorView<T, D> t, Fn&& fn)
} }
ElementWiseKernelHost(t, ctx->Threads(), fn); ElementWiseKernelHost(t, ctx->Threads(), fn);
} }
#endif // !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_ #endif // !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_HIP)
template <typename T, std::int32_t kDim> template <typename T, std::int32_t kDim>
auto cbegin(TensorView<T, kDim> const& v) { // NOLINT auto cbegin(TensorView<T, kDim> const& v) { // NOLINT

View File

@ -145,7 +145,7 @@ class Transform {
#if defined(XGBOOST_USE_HIP) #if defined(XGBOOST_USE_HIP)
dh::safe_cuda(hipSetDevice(device_)); dh::safe_cuda(hipSetDevice(device_));
#else #elif defined(XGBOOST_USE_CUDA)
dh::safe_cuda(cudaSetDevice(device_)); dh::safe_cuda(cudaSetDevice(device_));
#endif #endif

View File

@ -603,7 +603,7 @@ void DispatchDType(ArrayInterface<D> const array, std::int32_t device, Fn fn) {
}; };
switch (array.type) { switch (array.type) {
case ArrayInterfaceHandler::kF2: { case ArrayInterfaceHandler::kF2: {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600 #if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600) || defined(__HIP_PLATFORM_AMD__)
dispatch(__half{}); dispatch(__half{});
#endif #endif
break; break;