Support CUDA f16 without transformation. (#9207)

- Support f16 from cupy.
- Include CUDA header explicitly.
- Cleanup cmake nvtx support.
This commit is contained in:
Jiaming Yuan
2023-05-30 20:54:31 +08:00
committed by GitHub
parent 6f83d9c69a
commit 097f11b6e0
5 changed files with 27 additions and 54 deletions

View File

@@ -882,7 +882,7 @@ def _transform_cupy_array(data: DataType) -> CupyT:
if not hasattr(data, "__cuda_array_interface__") and hasattr(data, "__array__"):
data = cupy.array(data, copy=False)
if data.dtype.hasobject or data.dtype in [cupy.float16, cupy.bool_]:
if data.dtype.hasobject or data.dtype in [cupy.bool_]:
data = data.astype(cupy.float32, copy=False)
return data