Use array interface in Python prediction return. (#9855)
This commit is contained in:
parent
2c0fc97306
commit
39c637ee19
@ -683,7 +683,7 @@ using MatrixView = TensorView<T, 2>;
|
||||
*
|
||||
* `stream` is optionally included when data is on CUDA device.
|
||||
*/
|
||||
template <typename T, int32_t D>
|
||||
template <typename T, std::int32_t D>
|
||||
Json ArrayInterface(TensorView<T const, D> const &t) {
|
||||
Json array_interface{Object{}};
|
||||
array_interface["data"] = std::vector<Json>(2);
|
||||
@ -691,7 +691,7 @@ Json ArrayInterface(TensorView<T const, D> const &t) {
|
||||
array_interface["data"][1] = Boolean{true};
|
||||
if (t.Device().IsCUDA()) {
|
||||
// Change this once we have different CUDA stream.
|
||||
array_interface["stream"] = Null{};
|
||||
array_interface["stream"] = Integer{2};
|
||||
}
|
||||
std::vector<Json> shape(t.Shape().size());
|
||||
std::vector<Json> stride(t.Stride().size());
|
||||
|
||||
@ -100,6 +100,16 @@ def is_cupy_available() -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def import_cupy() -> types.ModuleType:
    """Return the ``cupy`` module, importing it on demand.

    Returns
    -------
    The imported ``cupy`` module.

    Raises
    ------
    ImportError
        If ``is_cupy_available`` reports that cupy cannot be used.
    """
    if is_cupy_available():
        # Deferred import: cupy is an optional dependency.
        import cupy  # pylint: disable=import-error

        return cupy

    raise ImportError("`cupy` is required for handling CUDA buffer.")
|
||||
|
||||
|
||||
try:
|
||||
import scipy.sparse as scipy_sparse
|
||||
from scipy.sparse import csr_matrix as scipy_csr
|
||||
|
||||
@ -3,7 +3,6 @@
|
||||
"""Core XGBoost Library."""
|
||||
import copy
|
||||
import ctypes
|
||||
import importlib.util
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
@ -45,7 +44,6 @@ from ._typing import (
|
||||
CStrPptr,
|
||||
CStrPtr,
|
||||
CTypeT,
|
||||
CupyT,
|
||||
DataType,
|
||||
FeatureInfo,
|
||||
FeatureNames,
|
||||
@ -55,7 +53,7 @@ from ._typing import (
|
||||
TransformedData,
|
||||
c_bst_ulong,
|
||||
)
|
||||
from .compat import PANDAS_INSTALLED, DataFrame, py_str
|
||||
from .compat import PANDAS_INSTALLED, DataFrame, import_cupy, py_str
|
||||
from .libpath import find_lib_path
|
||||
|
||||
|
||||
@ -380,34 +378,6 @@ def ctypes2numpy(cptr: CNumericPtr, length: int, dtype: Type[np.number]) -> np.n
|
||||
return res
|
||||
|
||||
|
||||
def ctypes2cupy(cptr: CNumericPtr, length: int, dtype: Type[np.number]) -> CupyT:
    """Copy the buffer behind a ctypes pointer into a newly allocated cupy array.

    Parameters
    ----------
    cptr :
        ctypes pointer to the first element of the buffer.
    length :
        Number of elements stored behind ``cptr``.
    dtype :
        Element type; only ``cupy.float32`` and ``cupy.uint32`` are accepted.

    Returns
    -------
    A one dimensional cupy array holding a copy of the ``length`` elements.

    Raises
    ------
    RuntimeError
        If ``dtype`` is not one of the supported types.
    """
    # pylint: disable=import-error
    import cupy
    from cupy.cuda.memory import MemoryPointer, UnownedMemory

    ctype_of: Dict[Type[np.number], Type[CNumeric]] = {
        cupy.float32: ctypes.c_float,
        cupy.uint32: ctypes.c_uint,
    }
    element_ctype = ctype_of.get(dtype)
    if element_ctype is None:
        raise RuntimeError(f"Supported types: {ctype_of.keys()}")

    address = ctypes.cast(cptr, ctypes.c_void_p).value
    # pylint: disable=c-extension-no-member,no-member
    expected_device = cupy.cuda.runtime.pointerGetAttributes(address).device
    n_bytes = length * ctypes.sizeof(element_ctype)
    # `owner` only exists to keep the memory alive through reference counting.
    # The unowned wrapper never leaves this function, so no owner is needed.
    # pylint: disable=unexpected-keyword-arg
    view = cupy.ndarray(
        (length,),
        dtype=dtype,
        memptr=MemoryPointer(UnownedMemory(address, n_bytes, owner=None), 0),
    )
    assert view.device.id == expected_device
    # Copy out of the borrowed buffer so the result owns its memory.
    return cupy.array(view, copy=True)
|
||||
|
||||
|
||||
def ctypes2buffer(cptr: CStrPtr, length: int) -> bytearray:
|
||||
"""Convert ctypes pointer to buffer type."""
|
||||
if not isinstance(cptr, ctypes.POINTER(ctypes.c_char)):
|
||||
@ -466,14 +436,8 @@ def from_array_interface(interface: dict) -> NumpyOrCupy:
|
||||
|
||||
if "stream" in interface:
|
||||
# A CUDA stream is present, so this is a __cuda_array_interface__.
|
||||
spec = importlib.util.find_spec("cupy")
|
||||
if spec is None:
|
||||
raise ImportError("`cupy` is required for handling CUDA buffer.")
|
||||
|
||||
import cupy as cp # pylint: disable=import-error
|
||||
|
||||
arr.__cuda_array_interface__ = interface
|
||||
out = cp.array(arr, copy=True)
|
||||
out = import_cupy().array(arr, copy=True)
|
||||
else:
|
||||
arr.__array_interface__ = interface
|
||||
out = np.array(arr, copy=True)
|
||||
@ -481,17 +445,42 @@ def from_array_interface(interface: dict) -> NumpyOrCupy:
|
||||
return out
|
||||
|
||||
|
||||
def make_array_interface(
    ptr: CNumericPtr, shape: Tuple[int, ...], dtype: Type[np.number], is_cuda: bool
) -> Dict[str, Union[int, tuple, None]]:
    """Make an __(cuda)_array_interface__ from a pointer.

    Parameters
    ----------
    ptr :
        ctypes pointer to the first element of the buffer. A null pointer is
        accepted only when ``shape`` describes zero elements.
    shape :
        Shape of the array stored behind ``ptr``.
    dtype :
        Element type of the buffer.
    is_cuda :
        Build a ``__cuda_array_interface__`` (this imports cupy) instead of a
        ``__array_interface__``.

    Returns
    -------
    The interface dictionary. ``shape`` is always the requested shape and
    ``strides`` is ``None``. For a non-null ``ptr`` the ``data`` entry is
    ``(address, True)``, i.e. the buffer is exposed as read-only.

    Raises
    ------
    ValueError
        If ``ptr`` is null while ``shape`` describes a non-empty array.
    """
    # Use an empty array to handle typestr and descr
    if is_cuda:
        empty = import_cupy().empty(shape=(0,), dtype=dtype)
        array = empty.__cuda_array_interface__  # pylint: disable=no-member
    else:
        empty = np.empty(shape=(0,), dtype=dtype)
        array = empty.__array_interface__  # pylint: disable=no-member

    addr = ctypes.cast(ptr, ctypes.c_void_p).value
    length = int(np.prod(shape))

    if addr is None:
        # Handle empty dataset. An explicit check is used instead of `assert`
        # so that it is still enforced when Python runs with -O.
        if length != 0:
            raise ValueError(
                f"Received a null pointer for a non-empty array of shape {shape}."
            )
        # Keep the placeholder's `data` entry but report the requested shape,
        # so that e.g. (0, 3) is not collapsed to (0,).
        array["shape"] = shape
        array["strides"] = None
        return array

    # The second element marks the buffer as read-only.
    array["data"] = (addr, True)
    if is_cuda:
        # 2 is the per-thread default stream in the CUDA Array Interface.
        array["stream"] = 2
    array["shape"] = shape
    # None means a C-contiguous layout.
    array["strides"] = None
    return array
|
||||
|
||||
|
||||
def _prediction_output(
    shape: CNumericPtr, dims: c_bst_ulong, predts: CFloatPtr, is_cuda: bool
) -> NumpyOrCupy:
    """Convert a raw prediction buffer into a numpy or cupy array.

    Parameters
    ----------
    shape :
        Pointer to ``dims.value`` unsigned 64-bit integers giving the shape
        of the prediction.
    dims :
        Number of dimensions of the prediction.
    predts :
        Pointer to the float32 prediction values.
    is_cuda :
        Whether ``predts`` should be described with a
        ``__cuda_array_interface__`` instead of an ``__array_interface__``.

    Returns
    -------
    The array built by ``from_array_interface`` from the described buffer.
    """
    # Convert to plain Python ints, which is what the array interface
    # protocol specifies for shape entries.
    arr_shape = tuple(ctypes2numpy(shape, dims.value, np.uint64).flatten().tolist())
    return from_array_interface(
        make_array_interface(predts, arr_shape, np.float32, is_cuda)
    )
|
||||
|
||||
|
||||
class DataIter(ABC): # pylint: disable=too-many-instance-attributes
|
||||
|
||||
@ -1097,6 +1097,8 @@ inline void CUDAEvent::Record(CUDAStreamView stream) { // NOLINT
|
||||
dh::safe_cuda(cudaEventRecord(event_, cudaStream_t{stream}));
|
||||
}
|
||||
|
||||
// Changing this has an effect on the prediction return, where we need to pass the pointer to
|
||||
// third-party libraries like CuPy.
|
||||
inline CUDAStreamView DefaultStream() {
|
||||
#ifdef CUDA_API_PER_THREAD_DEFAULT_STREAM
|
||||
return CUDAStreamView{cudaStreamPerThread};
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user