From 39c637ee199303d1c84ef25e55d2651d83c217e5 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Fri, 8 Dec 2023 03:42:14 +0800 Subject: [PATCH] Use array interface in Python prediction return. (#9855) --- include/xgboost/linalg.h | 4 +- python-package/xgboost/compat.py | 10 ++++ python-package/xgboost/core.py | 81 ++++++++++++++------------------ src/common/device_helpers.cuh | 2 + 4 files changed, 49 insertions(+), 48 deletions(-) diff --git a/include/xgboost/linalg.h b/include/xgboost/linalg.h index 8806818fb..581b2f080 100644 --- a/include/xgboost/linalg.h +++ b/include/xgboost/linalg.h @@ -683,7 +683,7 @@ using MatrixView = TensorView; * * `stream` is optionally included when data is on CUDA device. */ -template +template Json ArrayInterface(TensorView const &t) { Json array_interface{Object{}}; array_interface["data"] = std::vector(2); @@ -691,7 +691,7 @@ Json ArrayInterface(TensorView const &t) { array_interface["data"][1] = Boolean{true}; if (t.Device().IsCUDA()) { // Change this once we have different CUDA stream. - array_interface["stream"] = Null{}; + array_interface["stream"] = Integer{2}; } std::vector shape(t.Shape().size()); std::vector stride(t.Stride().size()); diff --git a/python-package/xgboost/compat.py b/python-package/xgboost/compat.py index c40dea5fd..7c11495f7 100644 --- a/python-package/xgboost/compat.py +++ b/python-package/xgboost/compat.py @@ -100,6 +100,16 @@ def is_cupy_available() -> bool: return False +def import_cupy() -> types.ModuleType: + """Import cupy.""" + if not is_cupy_available(): + raise ImportError("`cupy` is required for handling CUDA buffer.") + + import cupy # pylint: disable=import-error + + return cupy + + try: import scipy.sparse as scipy_sparse from scipy.sparse import csr_matrix as scipy_csr diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index bfc94aa04..ffc7db8dd 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -3,7 +3,6 @@ """Core XGBoost Library.""" import copy import ctypes -import importlib.util import json import os import re @@ -45,7 +44,6 @@ from ._typing import ( CStrPptr, CStrPtr, CTypeT, - CupyT, DataType, FeatureInfo, FeatureNames, @@ -55,7 +53,7 @@ from ._typing import ( TransformedData, c_bst_ulong, ) -from .compat import PANDAS_INSTALLED, DataFrame, py_str +from .compat import PANDAS_INSTALLED, DataFrame, import_cupy, py_str from .libpath import find_lib_path @@ -380,34 +378,6 @@ def ctypes2numpy(cptr: CNumericPtr, length: int, dtype: Type[np.number]) -> np.n return res -def ctypes2cupy(cptr: CNumericPtr, length: int, dtype: Type[np.number]) -> CupyT: - """Convert a ctypes pointer array to a cupy array.""" - # pylint: disable=import-error - import cupy - from cupy.cuda.memory import MemoryPointer, UnownedMemory - - CUPY_TO_CTYPES_MAPPING: Dict[Type[np.number], Type[CNumeric]] = { - cupy.float32: ctypes.c_float, - cupy.uint32: ctypes.c_uint, - } - if dtype not in CUPY_TO_CTYPES_MAPPING: - raise RuntimeError(f"Supported types: {CUPY_TO_CTYPES_MAPPING.keys()}") - addr = ctypes.cast(cptr, ctypes.c_void_p).value - # pylint: disable=c-extension-no-member,no-member - device = cupy.cuda.runtime.pointerGetAttributes(addr).device - # The owner field is just used to keep the memory alive with ref count. As - # unowned's life time is scoped within this function we don't need that. - unownd = UnownedMemory( - addr, length * ctypes.sizeof(CUPY_TO_CTYPES_MAPPING[dtype]), owner=None - ) - memptr = MemoryPointer(unownd, 0) - # pylint: disable=unexpected-keyword-arg - mem = cupy.ndarray((length,), dtype=dtype, memptr=memptr) - assert mem.device.id == device - arr = cupy.array(mem, copy=True) - return arr - - def ctypes2buffer(cptr: CStrPtr, length: int) -> bytearray: """Convert ctypes pointer to buffer type.""" if not isinstance(cptr, ctypes.POINTER(ctypes.c_char)): @@ -466,14 +436,8 @@ def from_array_interface(interface: dict) -> NumpyOrCupy: if "stream" in interface: # CUDA stream is presented, this is a __cuda_array_interface__. - spec = importlib.util.find_spec("cupy") - if spec is None: - raise ImportError("`cupy` is required for handling CUDA buffer.") - - import cupy as cp # pylint: disable=import-error - arr.__cuda_array_interface__ = interface - out = cp.array(arr, copy=True) + out = import_cupy().array(arr, copy=True) else: arr.__array_interface__ = interface out = np.array(arr, copy=True) @@ -481,17 +445,42 @@ def from_array_interface(interface: dict) -> NumpyOrCupy: return out +def make_array_interface( + ptr: CNumericPtr, shape: Tuple[int, ...], dtype: Type[np.number], is_cuda: bool +) -> Dict[str, Union[int, tuple, None]]: + """Make an __(cuda)_array_interface__ from a pointer.""" + # Use an empty array to handle typestr and descr + if is_cuda: + empty = import_cupy().empty(shape=(0,), dtype=dtype) + array = empty.__cuda_array_interface__ # pylint: disable=no-member + else: + empty = np.empty(shape=(0,), dtype=dtype) + array = empty.__array_interface__ # pylint: disable=no-member + + addr = ctypes.cast(ptr, ctypes.c_void_p).value + length = int(np.prod(shape)) + # Handle empty dataset. + assert addr is not None or length == 0 + + if addr is None: + return array + + array["data"] = (addr, True) + if is_cuda: + array["stream"] = 2 + array["shape"] = shape + array["strides"] = None + return array + + def _prediction_output( shape: CNumericPtr, dims: c_bst_ulong, predts: CFloatPtr, is_cuda: bool ) -> NumpyOrCupy: - arr_shape = ctypes2numpy(shape, dims.value, np.uint64) - length = int(np.prod(arr_shape)) - if is_cuda: - arr_predict = ctypes2cupy(predts, length, np.float32) - else: - arr_predict = ctypes2numpy(predts, length, np.float32) - arr_predict = arr_predict.reshape(arr_shape) - return arr_predict + arr_shape = tuple(ctypes2numpy(shape, dims.value, np.uint64).flatten()) + array = from_array_interface( + make_array_interface(predts, arr_shape, np.float32, is_cuda) + ) + return array class DataIter(ABC): # pylint: disable=too-many-instance-attributes diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh index ffe61800e..46f76c415 100644 --- a/src/common/device_helpers.cuh +++ b/src/common/device_helpers.cuh @@ -1097,6 +1097,8 @@ inline void CUDAEvent::Record(CUDAStreamView stream) { // NOLINT dh::safe_cuda(cudaEventRecord(event_, cudaStream_t{stream})); } +// Changing this has effect on prediction return, where we need to pass the pointer to +// third-party libraries like cuPy inline CUDAStreamView DefaultStream() { #ifdef CUDA_API_PER_THREAD_DEFAULT_STREAM return CUDAStreamView{cudaStreamPerThread};