fix array_interface.h half type
This commit is contained in:
parent
65097212b3
commit
b324d51f14
@ -60,7 +60,7 @@ void ElementWiseKernel(Context const* ctx, linalg::TensorView<T, D> t, Fn&& fn)
|
|||||||
}
|
}
|
||||||
ElementWiseKernelHost(t, ctx->Threads(), fn);
|
ElementWiseKernelHost(t, ctx->Threads(), fn);
|
||||||
}
|
}
|
||||||
#endif // !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_
|
#endif // !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_HIP)
|
||||||
|
|
||||||
template <typename T, std::int32_t kDim>
|
template <typename T, std::int32_t kDim>
|
||||||
auto cbegin(TensorView<T, kDim> const& v) { // NOLINT
|
auto cbegin(TensorView<T, kDim> const& v) { // NOLINT
|
||||||
|
|||||||
@ -145,7 +145,7 @@ class Transform {
|
|||||||
|
|
||||||
#if defined(XGBOOST_USE_HIP)
|
#if defined(XGBOOST_USE_HIP)
|
||||||
dh::safe_cuda(hipSetDevice(device_));
|
dh::safe_cuda(hipSetDevice(device_));
|
||||||
#else
|
#elif defined(XGBOOST_USE_CUDA)
|
||||||
dh::safe_cuda(cudaSetDevice(device_));
|
dh::safe_cuda(cudaSetDevice(device_));
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
|||||||
@ -603,7 +603,7 @@ void DispatchDType(ArrayInterface<D> const array, std::int32_t device, Fn fn) {
|
|||||||
};
|
};
|
||||||
switch (array.type) {
|
switch (array.type) {
|
||||||
case ArrayInterfaceHandler::kF2: {
|
case ArrayInterfaceHandler::kF2: {
|
||||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600
|
#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600) || defined(__HIP_PLATFORM_AMD__)
|
||||||
dispatch(__half{});
|
dispatch(__half{});
|
||||||
#endif
|
#endif
|
||||||
break;
|
break;
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user