Use array interface in Python prediction return. (#9855)

This commit is contained in:
Jiaming Yuan 2023-12-08 03:42:14 +08:00 committed by GitHub
parent 2c0fc97306
commit 39c637ee19
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 49 additions and 48 deletions

View File

@ -683,7 +683,7 @@ using MatrixView = TensorView<T, 2>;
* *
* `stream` is optionally included when data is on CUDA device. * `stream` is optionally included when data is on CUDA device.
*/ */
template <typename T, int32_t D> template <typename T, std::int32_t D>
Json ArrayInterface(TensorView<T const, D> const &t) { Json ArrayInterface(TensorView<T const, D> const &t) {
Json array_interface{Object{}}; Json array_interface{Object{}};
array_interface["data"] = std::vector<Json>(2); array_interface["data"] = std::vector<Json>(2);
@ -691,7 +691,7 @@ Json ArrayInterface(TensorView<T const, D> const &t) {
array_interface["data"][1] = Boolean{true}; array_interface["data"][1] = Boolean{true};
if (t.Device().IsCUDA()) { if (t.Device().IsCUDA()) {
// Change this once we have different CUDA stream. // Change this once we have different CUDA stream.
array_interface["stream"] = Null{}; array_interface["stream"] = Integer{2};
} }
std::vector<Json> shape(t.Shape().size()); std::vector<Json> shape(t.Shape().size());
std::vector<Json> stride(t.Stride().size()); std::vector<Json> stride(t.Stride().size());

View File

@ -100,6 +100,16 @@ def is_cupy_available() -> bool:
return False return False
def import_cupy() -> types.ModuleType:
    """Return the ``cupy`` module, importing it on demand.

    Returns
    -------
    The imported ``cupy`` module.

    Raises
    ------
    ImportError
        If :py:func:`is_cupy_available` reports that cupy cannot be used.
    """
    if is_cupy_available():
        # Deferred import: cupy is optional, so it must not be imported at
        # module load time.
        import cupy  # pylint: disable=import-error

        return cupy

    raise ImportError("`cupy` is required for handling CUDA buffer.")
try: try:
import scipy.sparse as scipy_sparse import scipy.sparse as scipy_sparse
from scipy.sparse import csr_matrix as scipy_csr from scipy.sparse import csr_matrix as scipy_csr

View File

@ -3,7 +3,6 @@
"""Core XGBoost Library.""" """Core XGBoost Library."""
import copy import copy
import ctypes import ctypes
import importlib.util
import json import json
import os import os
import re import re
@ -45,7 +44,6 @@ from ._typing import (
CStrPptr, CStrPptr,
CStrPtr, CStrPtr,
CTypeT, CTypeT,
CupyT,
DataType, DataType,
FeatureInfo, FeatureInfo,
FeatureNames, FeatureNames,
@ -55,7 +53,7 @@ from ._typing import (
TransformedData, TransformedData,
c_bst_ulong, c_bst_ulong,
) )
from .compat import PANDAS_INSTALLED, DataFrame, py_str from .compat import PANDAS_INSTALLED, DataFrame, import_cupy, py_str
from .libpath import find_lib_path from .libpath import find_lib_path
@ -380,34 +378,6 @@ def ctypes2numpy(cptr: CNumericPtr, length: int, dtype: Type[np.number]) -> np.n
return res return res
def ctypes2cupy(cptr: CNumericPtr, length: int, dtype: Type[np.number]) -> CupyT:
    """Copy ``length`` elements starting at ``cptr`` into a new cupy array.

    Parameters
    ----------
    cptr :
        ctypes pointer to the first element of the buffer.
    length :
        Number of elements (not bytes) to read.
    dtype :
        Element type; only ``cupy.float32`` and ``cupy.uint32`` are accepted.

    Returns
    -------
    A cupy array of shape ``(length,)`` holding a copy of the buffer.

    Raises
    ------
    RuntimeError
        If ``dtype`` is not one of the supported types.
    """
    # pylint: disable=import-error
    import cupy
    from cupy.cuda.memory import MemoryPointer, UnownedMemory

    ctype_of: Dict[Type[np.number], Type[CNumeric]] = {
        cupy.float32: ctypes.c_float,
        cupy.uint32: ctypes.c_uint,
    }
    if dtype not in ctype_of:
        raise RuntimeError(f"Supported types: {ctype_of.keys()}")

    address = ctypes.cast(cptr, ctypes.c_void_p).value
    # pylint: disable=c-extension-no-member,no-member
    expected_device = cupy.cuda.runtime.pointerGetAttributes(address).device
    n_bytes = length * ctypes.sizeof(ctype_of[dtype])
    # No owner is registered: the borrowed memory is only referenced inside
    # this function and the result returned to the caller is a copy.
    borrowed = UnownedMemory(address, n_bytes, owner=None)
    # pylint: disable=unexpected-keyword-arg
    view = cupy.ndarray((length,), dtype=dtype, memptr=MemoryPointer(borrowed, 0))
    # The wrapped view must report the same device as the raw pointer.
    assert view.device.id == expected_device
    return cupy.array(view, copy=True)
def ctypes2buffer(cptr: CStrPtr, length: int) -> bytearray: def ctypes2buffer(cptr: CStrPtr, length: int) -> bytearray:
"""Convert ctypes pointer to buffer type.""" """Convert ctypes pointer to buffer type."""
if not isinstance(cptr, ctypes.POINTER(ctypes.c_char)): if not isinstance(cptr, ctypes.POINTER(ctypes.c_char)):
@ -466,14 +436,8 @@ def from_array_interface(interface: dict) -> NumpyOrCupy:
if "stream" in interface: if "stream" in interface:
# CUDA stream is presented, this is a __cuda_array_interface__. # CUDA stream is presented, this is a __cuda_array_interface__.
spec = importlib.util.find_spec("cupy")
if spec is None:
raise ImportError("`cupy` is required for handling CUDA buffer.")
import cupy as cp # pylint: disable=import-error
arr.__cuda_array_interface__ = interface arr.__cuda_array_interface__ = interface
out = cp.array(arr, copy=True) out = import_cupy().array(arr, copy=True)
else: else:
arr.__array_interface__ = interface arr.__array_interface__ = interface
out = np.array(arr, copy=True) out = np.array(arr, copy=True)
@ -481,17 +445,42 @@ def from_array_interface(interface: dict) -> NumpyOrCupy:
return out return out
def make_array_interface(
    ptr: CNumericPtr, shape: Tuple[int, ...], dtype: Type[np.number], is_cuda: bool
) -> Dict[str, Union[int, tuple, None]]:
    """Build an ``__array_interface__`` / ``__cuda_array_interface__`` dict for ``ptr``.

    Parameters
    ----------
    ptr :
        ctypes pointer to the first element of the buffer. May be null only
        when ``shape`` describes zero elements.
    shape :
        Shape to advertise for the buffer.
    dtype :
        Element type of the buffer.
    is_cuda :
        Produce the CUDA flavour of the interface (requires cupy) instead of
        the numpy one.

    Returns
    -------
    The interface dictionary. For a null pointer this is the interface of a
    zero-length array of ``dtype``, returned unmodified.
    """
    # A zero-length array of the requested dtype supplies the dtype-dependent
    # fields (typestr, descr), so they don't have to be spelled out by hand.
    if is_cuda:
        template = import_cupy().empty(shape=(0,), dtype=dtype)
        interface = template.__cuda_array_interface__  # pylint: disable=no-member
    else:
        template = np.empty(shape=(0,), dtype=dtype)
        interface = template.__array_interface__  # pylint: disable=no-member

    address = ctypes.cast(ptr, ctypes.c_void_p).value
    n_elements = int(np.prod(shape))
    # A null pointer is only acceptable for an empty dataset.
    assert address is not None or n_elements == 0

    if address is not None:
        # The second item of ``data`` is the read-only flag. ``strides=None``
        # declares a C-contiguous layout.
        interface.update(data=(address, True), shape=shape, strides=None)
        if is_cuda:
            # 2 is the CUDA array interface value for the per-thread default
            # stream.
            interface["stream"] = 2

    return interface
def _prediction_output( def _prediction_output(
shape: CNumericPtr, dims: c_bst_ulong, predts: CFloatPtr, is_cuda: bool shape: CNumericPtr, dims: c_bst_ulong, predts: CFloatPtr, is_cuda: bool
) -> NumpyOrCupy: ) -> NumpyOrCupy:
arr_shape = ctypes2numpy(shape, dims.value, np.uint64) arr_shape = tuple(ctypes2numpy(shape, dims.value, np.uint64).flatten())
length = int(np.prod(arr_shape)) array = from_array_interface(
if is_cuda: make_array_interface(predts, arr_shape, np.float32, is_cuda)
arr_predict = ctypes2cupy(predts, length, np.float32) )
else: return array
arr_predict = ctypes2numpy(predts, length, np.float32)
arr_predict = arr_predict.reshape(arr_shape)
return arr_predict
class DataIter(ABC): # pylint: disable=too-many-instance-attributes class DataIter(ABC): # pylint: disable=too-many-instance-attributes

View File

@ -1097,6 +1097,8 @@ inline void CUDAEvent::Record(CUDAStreamView stream) { // NOLINT
dh::safe_cuda(cudaEventRecord(event_, cudaStream_t{stream})); dh::safe_cuda(cudaEventRecord(event_, cudaStream_t{stream}));
} }
// Changing this affects the prediction return path, where the result pointer is passed
// to third-party libraries such as CuPy.
inline CUDAStreamView DefaultStream() { inline CUDAStreamView DefaultStream() {
#ifdef CUDA_API_PER_THREAD_DEFAULT_STREAM #ifdef CUDA_API_PER_THREAD_DEFAULT_STREAM
return CUDAStreamView{cudaStreamPerThread}; return CUDAStreamView{cudaStreamPerThread};