fix array_interface.h half type
This commit is contained in:
parent
65097212b3
commit
b324d51f14
@ -60,7 +60,7 @@ void ElementWiseKernel(Context const* ctx, linalg::TensorView<T, D> t, Fn&& fn)
|
||||
}
|
||||
ElementWiseKernelHost(t, ctx->Threads(), fn);
|
||||
}
|
||||
#endif // !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_
|
||||
#endif // !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_HIP)
|
||||
|
||||
template <typename T, std::int32_t kDim>
|
||||
auto cbegin(TensorView<T, kDim> const& v) { // NOLINT
|
||||
|
||||
@ -145,7 +145,7 @@ class Transform {
|
||||
|
||||
#if defined(XGBOOST_USE_HIP)
|
||||
dh::safe_cuda(hipSetDevice(device_));
|
||||
#else
|
||||
#elif defined(XGBOOST_USE_CUDA)
|
||||
dh::safe_cuda(cudaSetDevice(device_));
|
||||
#endif
|
||||
|
||||
|
||||
@ -603,7 +603,7 @@ void DispatchDType(ArrayInterface<D> const array, std::int32_t device, Fn fn) {
|
||||
};
|
||||
switch (array.type) {
|
||||
case ArrayInterfaceHandler::kF2: {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600
|
||||
#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600) || defined(__HIP_PLATFORM_AMD__)
|
||||
dispatch(__half{});
|
||||
#endif
|
||||
break;
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user