Support half type from cupy. (#8487)

This commit is contained in:
Jiaming Yuan
2022-11-30 17:56:42 +08:00
committed by GitHub
parent addaa63732
commit 157e98edf7
3 changed files with 36 additions and 4 deletions

View File

@@ -42,6 +42,8 @@ def dmatrix_from_cupy(input_type, DMatrixT, missing=np.NAN):
def _test_from_cupy(DMatrixT):
'''Test constructing DMatrix from cupy'''
import cupy as cp
dmatrix_from_cupy(np.float16, DMatrixT, np.NAN)
dmatrix_from_cupy(np.float32, DMatrixT, np.NAN)
dmatrix_from_cupy(np.float64, DMatrixT, np.NAN)