fix array_interface.h half type

This commit is contained in:
amdsc21 2023-05-02 20:50:50 +02:00
parent 65097212b3
commit b324d51f14
3 changed files with 3 additions and 3 deletions

View File

@ -60,7 +60,7 @@ void ElementWiseKernel(Context const* ctx, linalg::TensorView<T, D> t, Fn&& fn)
}
ElementWiseKernelHost(t, ctx->Threads(), fn);
}
#endif // !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_
#endif // !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_HIP)
template <typename T, std::int32_t kDim>
auto cbegin(TensorView<T, kDim> const& v) { // NOLINT

View File

@ -145,7 +145,7 @@ class Transform {
#if defined(XGBOOST_USE_HIP)
dh::safe_cuda(hipSetDevice(device_));
#else
#elif defined(XGBOOST_USE_CUDA)
dh::safe_cuda(cudaSetDevice(device_));
#endif

View File

@ -603,7 +603,7 @@ void DispatchDType(ArrayInterface<D> const array, std::int32_t device, Fn fn) {
};
switch (array.type) {
case ArrayInterfaceHandler::kF2: {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600
#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600) || defined(__HIP_PLATFORM_AMD__)
dispatch(__half{});
#endif
break;