Enhance inplace prediction. (#6653)

* Accept array interface for csr and array.
* Accept an optional proxy dmatrix for metainfo.

This constructs an explicit `_ProxyDMatrix` type in Python.

* Remove unused doc.
* Add strict output.
This commit is contained in:
Jiaming Yuan 2021-02-02 11:41:46 +08:00 committed by GitHub
parent 87ab1ad607
commit 411592a347
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 955 additions and 530 deletions

View File

@ -1,5 +1,5 @@
/*! /*!
* Copyright 2014-2020 by Contributors * Copyright 2014-2021 by Contributors
* \file gbm.h * \file gbm.h
* \brief Interface of gradient booster, * \brief Interface of gradient booster,
* that learns through gradient statistics. * that learns through gradient statistics.
@ -118,7 +118,7 @@ class GradientBooster : public Model, public Configurable {
* \param layer_begin (Optional) Beginning of boosted tree layer used for prediction. * \param layer_begin (Optional) Beginning of boosted tree layer used for prediction.
* \param layer_end (Optional) End of booster layer. 0 means do not limit trees. * \param layer_end (Optional) End of booster layer. 0 means do not limit trees.
*/ */
virtual void InplacePredict(dmlc::any const &, float, virtual void InplacePredict(dmlc::any const &, std::shared_ptr<DMatrix>, float,
PredictionCacheEntry*, PredictionCacheEntry*,
uint32_t, uint32_t,
uint32_t) const { uint32_t) const {

View File

@ -308,6 +308,7 @@ struct StringView {
public: public:
StringView() = default; StringView() = default;
StringView(CharT const* str, size_t size) : str_{str}, size_{size} {} StringView(CharT const* str, size_t size) : str_{str}, size_{size} {}
explicit StringView(std::string const& str): str_{str.c_str()}, size_{str.size()} {}
explicit StringView(CharT const* str) : str_{str}, size_{Traits::length(str)} {} explicit StringView(CharT const* str) : str_{str}, size_{Traits::length(str)} {}
CharT const& operator[](size_t p) const { return str_[p]; } CharT const& operator[](size_t p) const { return str_[p]; }

View File

@ -1,5 +1,5 @@
/*! /*!
* Copyright 2015-2020 by Contributors * Copyright 2015-2021 by Contributors
* \file learner.h * \file learner.h
* \brief Learner interface that integrates objective, gbm and evaluation together. * \brief Learner interface that integrates objective, gbm and evaluation together.
* This is the user facing XGBoost training module. * This is the user facing XGBoost training module.
@ -30,6 +30,15 @@ class ObjFunction;
class DMatrix; class DMatrix;
class Json; class Json;
enum class PredictionType : std::uint8_t { // NOLINT
kValue = 0,
kMargin = 1,
kContribution = 2,
kApproxContribution = 3,
kInteraction = 4,
kLeaf = 5
};
/*! \brief entry to easily hold returning information */ /*! \brief entry to easily hold returning information */
struct XGBAPIThreadLocalEntry { struct XGBAPIThreadLocalEntry {
/*! \brief result holder for returning string */ /*! \brief result holder for returning string */
@ -42,7 +51,10 @@ struct XGBAPIThreadLocalEntry {
std::vector<bst_float> ret_vec_float; std::vector<bst_float> ret_vec_float;
/*! \brief temp variable of gradient pairs. */ /*! \brief temp variable of gradient pairs. */
std::vector<GradientPair> tmp_gpair; std::vector<GradientPair> tmp_gpair;
/*! \brief Temp variable for returning prediction result. */
PredictionCacheEntry prediction_entry; PredictionCacheEntry prediction_entry;
/*! \brief Temp variable for returning prediction shape. */
std::vector<bst_ulong> prediction_shape;
}; };
/*! /*!
@ -123,13 +135,17 @@ class Learner : public Model, public Configurable, public dmlc::Serializable {
* \brief Inplace prediction. * \brief Inplace prediction.
* *
* \param x A type erased data adapter. * \param x A type erased data adapter.
* \param p_m An optional Proxy DMatrix object storing meta info like
* base margin. Can be nullptr.
* \param type Prediction type. * \param type Prediction type.
* \param missing Missing value in the data. * \param missing Missing value in the data.
* \param [in,out] out_preds Pointer to output prediction vector. * \param [in,out] out_preds Pointer to output prediction vector.
* \param layer_begin (Optional) Beginning of boosted tree layer used for prediction. * \param layer_begin Beginning of boosted tree layer used for prediction.
* \param layer_end (Optional) End of booster layer. 0 means do not limit trees. * \param layer_end End of booster layer. 0 means do not limit trees.
*/ */
virtual void InplacePredict(dmlc::any const& x, std::string const& type, virtual void InplacePredict(dmlc::any const &x,
std::shared_ptr<DMatrix> p_m,
PredictionType type,
float missing, float missing,
HostDeviceVector<bst_float> **out_preds, HostDeviceVector<bst_float> **out_preds,
uint32_t layer_begin, uint32_t layer_end) = 0; uint32_t layer_begin, uint32_t layer_end) = 0;
@ -138,6 +154,7 @@ class Learner : public Model, public Configurable, public dmlc::Serializable {
* \brief Get number of boosted rounds from gradient booster. * \brief Get number of boosted rounds from gradient booster.
*/ */
virtual int32_t BoostedRounds() const = 0; virtual int32_t BoostedRounds() const = 0;
virtual uint32_t Groups() const = 0;
void LoadModel(Json const& in) override = 0; void LoadModel(Json const& in) override = 0;
void SaveModel(Json* out) const override = 0; void SaveModel(Json* out) const override = 0;

View File

@ -1,5 +1,5 @@
/*! /*!
* Copyright 2017-2020 by Contributors * Copyright 2017-2021 by Contributors
* \file predictor.h * \file predictor.h
* \brief Interface of predictor, * \brief Interface of predictor,
* performs predictions for a gradient booster. * performs predictions for a gradient booster.
@ -142,10 +142,14 @@ class Predictor {
* \param [in,out] out_preds The output preds. * \param [in,out] out_preds The output preds.
* \param tree_begin (Optional) Beginning of boosted trees used for prediction. * \param tree_begin (Optional) Beginning of boosted trees used for prediction.
* \param tree_end (Optional) End of booster trees. 0 means do not limit trees. * \param tree_end (Optional) End of booster trees. 0 means do not limit trees.
*
* \return True if the data can be handled by current predictor, false otherwise.
*/ */
virtual void InplacePredict(dmlc::any const &x, const gbm::GBTreeModel &model, virtual bool InplacePredict(dmlc::any const &x, std::shared_ptr<DMatrix> p_m,
float missing, PredictionCacheEntry *out_preds, const gbm::GBTreeModel &model, float missing,
uint32_t tree_begin = 0, uint32_t tree_end = 0) const = 0; PredictionCacheEntry *out_preds,
uint32_t tree_begin = 0,
uint32_t tree_end = 0) const = 0;
/** /**
* \brief online prediction function, predict score for one instance at a time * \brief online prediction function, predict score for one instance at a time
* NOTE: use the batch prediction interface if possible, batch prediction is * NOTE: use the batch prediction interface if possible, batch prediction is

View File

@ -58,21 +58,23 @@ CallbackEnv = collections.namedtuple(
"evaluation_result_list"]) "evaluation_result_list"])
def from_pystr_to_cstr(data): def from_pystr_to_cstr(data: Union[str, List[str]]):
"""Convert a list of Python str to C pointer """Convert a Python str or list of Python str to C pointer
Parameters Parameters
---------- ----------
data : list data
list of str str or list of str
""" """
if not isinstance(data, list): if isinstance(data, str):
raise NotImplementedError return bytes(data, "utf-8")
if isinstance(data, list):
pointers = (ctypes.c_char_p * len(data))() pointers = (ctypes.c_char_p * len(data))()
data = [bytes(d, 'utf-8') for d in data] data = [bytes(d, 'utf-8') for d in data]
pointers[:] = data pointers[:] = data
return pointers return pointers
raise TypeError()
def from_cstr_to_pystr(data, length): def from_cstr_to_pystr(data, length):
@ -190,21 +192,40 @@ def _check_call(ret):
raise XGBoostError(py_str(_LIB.XGBGetLastError())) raise XGBoostError(py_str(_LIB.XGBGetLastError()))
def ctypes2numpy(cptr, length, dtype) -> np.ndarray: def _numpy2ctypes_type(dtype):
"""Convert a ctypes pointer array to a numpy array.""" _NUMPY_TO_CTYPES_MAPPING = {
NUMPY_TO_CTYPES_MAPPING = {
np.float32: ctypes.c_float, np.float32: ctypes.c_float,
np.float64: ctypes.c_double,
np.uint32: ctypes.c_uint, np.uint32: ctypes.c_uint,
np.uint64: ctypes.c_uint64,
np.int32: ctypes.c_int32,
np.int64: ctypes.c_int64,
} }
if dtype not in NUMPY_TO_CTYPES_MAPPING: if np.intc is not np.int32: # Windows
raise RuntimeError('Supported types: {}'.format( _NUMPY_TO_CTYPES_MAPPING[np.intc] = _NUMPY_TO_CTYPES_MAPPING[np.int32]
NUMPY_TO_CTYPES_MAPPING.keys())) if dtype not in _NUMPY_TO_CTYPES_MAPPING.keys():
ctype = NUMPY_TO_CTYPES_MAPPING[dtype] raise TypeError(
f"Supported types: {_NUMPY_TO_CTYPES_MAPPING.keys()}, got: {dtype}"
)
return _NUMPY_TO_CTYPES_MAPPING[dtype]
def _array_interface(data: np.ndarray) -> bytes:
interface = data.__array_interface__
if "mask" in interface:
interface["mask"] = interface["mask"].__array_interface__
interface_str = bytes(json.dumps(interface, indent=2), "utf-8")
return interface_str
def ctypes2numpy(cptr, length, dtype):
"""Convert a ctypes pointer array to a numpy array."""
ctype = _numpy2ctypes_type(dtype)
if not isinstance(cptr, ctypes.POINTER(ctype)): if not isinstance(cptr, ctypes.POINTER(ctype)):
raise RuntimeError('expected {} pointer'.format(ctype)) raise RuntimeError("expected {} pointer".format(ctype))
res = np.zeros(length, dtype=dtype) res = np.zeros(length, dtype=dtype)
if not ctypes.memmove(res.ctypes.data, cptr, length * res.strides[0]): if not ctypes.memmove(res.ctypes.data, cptr, length * res.strides[0]):
raise RuntimeError('memmove failed') raise RuntimeError("memmove failed")
return res return res
@ -214,25 +235,21 @@ def ctypes2cupy(cptr, length, dtype):
import cupy import cupy
from cupy.cuda.memory import MemoryPointer from cupy.cuda.memory import MemoryPointer
from cupy.cuda.memory import UnownedMemory from cupy.cuda.memory import UnownedMemory
CUPY_TO_CTYPES_MAPPING = {
cupy.float32: ctypes.c_float, CUPY_TO_CTYPES_MAPPING = {cupy.float32: ctypes.c_float, cupy.uint32: ctypes.c_uint}
cupy.uint32: ctypes.c_uint
}
if dtype not in CUPY_TO_CTYPES_MAPPING.keys(): if dtype not in CUPY_TO_CTYPES_MAPPING.keys():
raise RuntimeError('Supported types: {}'.format( raise RuntimeError("Supported types: {}".format(CUPY_TO_CTYPES_MAPPING.keys()))
CUPY_TO_CTYPES_MAPPING.keys()
))
addr = ctypes.cast(cptr, ctypes.c_void_p).value addr = ctypes.cast(cptr, ctypes.c_void_p).value
# pylint: disable=c-extension-no-member,no-member # pylint: disable=c-extension-no-member,no-member
device = cupy.cuda.runtime.pointerGetAttributes(addr).device device = cupy.cuda.runtime.pointerGetAttributes(addr).device
# The owner field is just used to keep the memory alive with ref count. As # The owner field is just used to keep the memory alive with ref count. As
# unowned's life time is scoped within this function we don't need that. # unowned's life time is scoped within this function we don't need that.
unownd = UnownedMemory( unownd = UnownedMemory(
addr, length.value * ctypes.sizeof(CUPY_TO_CTYPES_MAPPING[dtype]), addr, length * ctypes.sizeof(CUPY_TO_CTYPES_MAPPING[dtype]), owner=None
owner=None) )
memptr = MemoryPointer(unownd, 0) memptr = MemoryPointer(unownd, 0)
# pylint: disable=unexpected-keyword-arg # pylint: disable=unexpected-keyword-arg
mem = cupy.ndarray((length.value, ), dtype=dtype, memptr=memptr) mem = cupy.ndarray((length,), dtype=dtype, memptr=memptr)
assert mem.device.id == device assert mem.device.id == device
arr = cupy.array(mem, copy=True) arr = cupy.array(mem, copy=True)
return arr return arr
@ -256,28 +273,29 @@ def c_str(string):
def c_array(ctype, values): def c_array(ctype, values):
"""Convert a python string to c array.""" """Convert a python string to c array."""
if (isinstance(values, np.ndarray) if isinstance(values, np.ndarray) and values.dtype.itemsize == ctypes.sizeof(ctype):
and values.dtype.itemsize == ctypes.sizeof(ctype)):
return (ctype * len(values)).from_buffer_copy(values) return (ctype * len(values)).from_buffer_copy(values)
return (ctype * len(values))(*values) return (ctype * len(values))(*values)
def _prediction_output(shape, dims, predts, is_cuda):
arr_shape: np.ndarray = ctypes2numpy(shape, dims.value, np.uint64)
length = int(np.prod(arr_shape))
if is_cuda:
arr_predict = ctypes2cupy(predts, length, np.float32)
else:
arr_predict: np.ndarray = ctypes2numpy(predts, length, np.float32)
arr_predict = arr_predict.reshape(arr_shape)
return arr_predict
class DataIter: class DataIter:
'''The interface for user defined data iterator. Currently is only '''The interface for user defined data iterator. Currently is only supported by Device
supported by Device DMatrix. DMatrix.
Parameters
----------
rows : int
Total number of rows combining all batches.
cols : int
Number of columns for each batch.
''' '''
def __init__(self): def __init__(self):
proxy_handle = ctypes.c_void_p() self._handle = _ProxyDMatrix()
_check_call(_LIB.XGProxyDMatrixCreate(ctypes.byref(proxy_handle)))
self._handle = DeviceQuantileDMatrix(proxy_handle)
self.exception = None self.exception = None
@property @property
@ -300,12 +318,7 @@ class DataIter:
if self.exception is not None: if self.exception is not None:
return 0 return 0
def data_handle(data, label=None, weight=None, base_margin=None, def data_handle(data, feature_names=None, feature_types=None, **kwargs):
group=None,
qid=None,
label_lower_bound=None, label_upper_bound=None,
feature_names=None, feature_types=None,
feature_weights=None):
from .data import dispatch_device_quantile_dmatrix_set_data from .data import dispatch_device_quantile_dmatrix_set_data
from .data import _device_quantile_transform from .data import _device_quantile_transform
data, feature_names, feature_types = _device_quantile_transform( data, feature_names, feature_types = _device_quantile_transform(
@ -313,16 +326,9 @@ class DataIter:
) )
dispatch_device_quantile_dmatrix_set_data(self.proxy, data) dispatch_device_quantile_dmatrix_set_data(self.proxy, data)
self.proxy.set_info( self.proxy.set_info(
label=label,
weight=weight,
base_margin=base_margin,
group=group,
qid=qid,
label_lower_bound=label_lower_bound,
label_upper_bound=label_upper_bound,
feature_names=feature_names, feature_names=feature_names,
feature_types=feature_types, feature_types=feature_types,
feature_weights=feature_weights **kwargs,
) )
try: try:
# Defer the exception in order to return 0 and stop the iteration. # Defer the exception in order to return 0 and stop the iteration.
@ -558,7 +564,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
feature_types=None, feature_types=None,
feature_weights=None feature_weights=None
) -> None: ) -> None:
"""Set meta info for DMatrix. See doc string for DMatrix constructor.""" """Set meta info for DMatrix. See doc string for :py:obj:`xgboost.DMatrix`."""
from .data import dispatch_meta_backend from .data import dispatch_meta_backend
if label is not None: if label is not None:
@ -959,18 +965,52 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
c_bst_ulong(0))) c_bst_ulong(0)))
class _ProxyDMatrix(DMatrix):
"""A placeholder class when DMatrix cannot be constructed (DeviceQuantileDMatrix,
inplace_predict).
"""
def __init__(self): # pylint: disable=super-init-not-called
self.handle = ctypes.c_void_p()
_check_call(_LIB.XGProxyDMatrixCreate(ctypes.byref(self.handle)))
def _set_data_from_cuda_interface(self, data):
'''Set data from CUDA array interface.'''
interface = data.__cuda_array_interface__
interface_str = bytes(json.dumps(interface, indent=2), 'utf-8')
_check_call(
_LIB.XGDeviceQuantileDMatrixSetDataCudaArrayInterface(
self.handle,
interface_str
)
)
def _set_data_from_cuda_columnar(self, data):
'''Set data from CUDA columnar format.'''
from .data import _cudf_array_interfaces
interfaces_str = _cudf_array_interfaces(data)
_check_call(
_LIB.XGDeviceQuantileDMatrixSetDataCudaColumnar(
self.handle,
interfaces_str
)
)
class DeviceQuantileDMatrix(DMatrix): class DeviceQuantileDMatrix(DMatrix):
"""Device memory Data Matrix used in XGBoost for training with """Device memory Data Matrix used in XGBoost for training with tree_method='gpu_hist'. Do
tree_method='gpu_hist'. Do not use this for test/validation tasks as some not use this for test/validation tasks as some information may be lost in
information may be lost in quantisation. This DMatrix is primarily designed quantisation. This DMatrix is primarily designed to save memory in training from
to save memory in training from device memory inputs by avoiding device memory inputs by avoiding intermediate storage. Set max_bin to control the
intermediate storage. Set max_bin to control the number of bins during number of bins during quantisation. See doc string in :py:obj:`xgboost.DMatrix` for
quantisation. See doc string in `DMatrix` for documents on meta info. documents on meta info.
You can construct DeviceQuantileDMatrix from cupy/cudf/dlpack. You can construct DeviceQuantileDMatrix from cupy/cudf/dlpack.
.. versionadded:: 1.1.0 .. versionadded:: 1.1.0
""" """
@_deprecate_positional_args @_deprecate_positional_args
def __init__( # pylint: disable=super-init-not-called def __init__( # pylint: disable=super-init-not-called
self, self,
@ -1000,58 +1040,72 @@ class DeviceQuantileDMatrix(DMatrix):
if isinstance(data, ctypes.c_void_p): if isinstance(data, ctypes.c_void_p):
self.handle = data self.handle = data
return return
from .data import init_device_quantile_dmatrix
handle, feature_names, feature_types = init_device_quantile_dmatrix(
data,
label=label, weight=weight,
base_margin=base_margin,
group=group,
qid=qid,
missing=self.missing,
label_lower_bound=label_lower_bound,
label_upper_bound=label_upper_bound,
feature_weights=feature_weights,
feature_names=feature_names,
feature_types=feature_types,
threads=self.nthread,
max_bin=self.max_bin,
)
if enable_categorical: if enable_categorical:
raise NotImplementedError( raise NotImplementedError(
'categorical support is not enabled on DeviceQuantileDMatrix.' 'categorical support is not enabled on DeviceQuantileDMatrix.'
) )
self.handle = handle
if qid is not None and group is not None: if qid is not None and group is not None:
raise ValueError( raise ValueError(
'Only one of the eval_qid or eval_group for each evaluation ' 'Only one of the eval_qid or eval_group for each evaluation '
'dataset should be provided.' 'dataset should be provided.'
) )
self.feature_names = feature_names self._init(
self.feature_types = feature_types data,
label=label,
def _set_data_from_cuda_interface(self, data): weight=weight,
'''Set data from CUDA array interface.''' base_margin=base_margin,
interface = data.__cuda_array_interface__ group=group,
interface_str = bytes(json.dumps(interface, indent=2), 'utf-8') qid=qid,
_check_call( label_lower_bound=label_lower_bound,
_LIB.XGDeviceQuantileDMatrixSetDataCudaArrayInterface( label_upper_bound=label_upper_bound,
self.handle, feature_weights=feature_weights,
interface_str feature_names=feature_names,
) feature_types=feature_types,
) )
def _set_data_from_cuda_columnar(self, data): def _init(self, data, feature_names, feature_types, **meta):
'''Set data from CUDA columnar format.1''' from .data import (
from .data import _cudf_array_interfaces _is_dlpack,
interfaces_str = _cudf_array_interfaces(data) _transform_dlpack,
_check_call( _is_iter,
_LIB.XGDeviceQuantileDMatrixSetDataCudaColumnar( SingleBatchInternalIter,
self.handle,
interfaces_str
) )
if _is_dlpack(data):
# We specialize for dlpack because cupy will take the memory from it so
# it can't be transformed twice.
data = _transform_dlpack(data)
if _is_iter(data):
it = data
else:
it = SingleBatchInternalIter(
data, **meta, feature_names=feature_names, feature_types=feature_types
) )
reset_callback = ctypes.CFUNCTYPE(None, ctypes.c_void_p)(it.reset_wrapper)
next_callback = ctypes.CFUNCTYPE(
ctypes.c_int,
ctypes.c_void_p,
)(it.next_wrapper)
handle = ctypes.c_void_p()
ret = _LIB.XGDeviceQuantileDMatrixCreateFromCallback(
None,
it.proxy.handle,
reset_callback,
next_callback,
ctypes.c_float(self.missing),
ctypes.c_int(self.nthread),
ctypes.c_int(self.max_bin),
ctypes.byref(handle),
)
if it.exception:
raise it.exception
# delay check_call to throw intermediate exception first
_check_call(ret)
self.handle = handle
Objective = Callable[[np.ndarray, DMatrix], Tuple[np.ndarray, np.ndarray]] Objective = Callable[[np.ndarray, DMatrix], Tuple[np.ndarray, np.ndarray]]
Metric = Callable[[np.ndarray, DMatrix], Tuple[str, float]] Metric = Callable[[np.ndarray, DMatrix], Tuple[str, float]]
@ -1346,7 +1400,7 @@ class Booster(object):
def boost(self, dtrain, grad, hess): def boost(self, dtrain, grad, hess):
"""Boost the booster for one iteration, with customized gradient """Boost the booster for one iteration, with customized gradient
statistics. Like :func:`xgboost.core.Booster.update`, this statistics. Like :py:func:`xgboost.Booster.update`, this
function should not be called directly by users. function should not be called directly by users.
Parameters Parameters
@ -1360,7 +1414,9 @@ class Booster(object):
""" """
if len(grad) != len(hess): if len(grad) != len(hess):
raise ValueError('grad / hess length mismatch: {} / {}'.format(len(grad), len(hess))) raise ValueError(
'grad / hess length mismatch: {} / {}'.format(len(grad), len(hess))
)
if not isinstance(dtrain, DMatrix): if not isinstance(dtrain, DMatrix):
raise TypeError('invalid training matrix: {}'.format(type(dtrain).__name__)) raise TypeError('invalid training matrix: {}'.format(type(dtrain).__name__))
self._validate_features(dtrain) self._validate_features(dtrain)
@ -1453,17 +1509,12 @@ class Booster(object):
training=False): training=False):
"""Predict with data. """Predict with data.
.. note:: This function is not thread safe except for ``gbtree`` .. note:: This function is not thread safe except for ``gbtree`` booster.
booster.
For ``gbtree`` booster, the thread safety is guaranteed by locks. When using booster other than ``gbtree``, predict can only be called from one
For lock free prediction use ``inplace_predict`` instead. Also, the thread. If you want to run prediction using multiple thread, call
safety does not hold when used in conjunction with other methods. :py:meth:`xgboost.Booster.copy` to make copies of model object and then call
``predict()``.
When using booster other than ``gbtree``, predict can only be called
from one thread. If you want to run prediction using multiple
thread, call ``bst.copy()`` to make copies of model object and then
call ``predict()``.
Parameters Parameters
---------- ----------
@ -1579,9 +1630,17 @@ class Booster(object):
preds = preds.reshape(nrow, chunk_size) preds = preds.reshape(nrow, chunk_size)
return preds return preds
def inplace_predict(self, data, iteration_range=(0, 0), def inplace_predict(
predict_type='value', missing=np.nan): self,
'''Run prediction in-place, Unlike ``predict`` method, inplace prediction does data,
iteration_range: Tuple[int, int] = (0, 0),
predict_type: str = "value",
missing: float = np.nan,
validate_features: bool = True,
base_margin: Any = None,
strict_shape: bool = False
):
"""Run prediction in-place, Unlike ``predict`` method, inplace prediction does
not cache the prediction result. not cache the prediction result.
Calling only ``inplace_predict`` in multiple threads is safe and lock Calling only ``inplace_predict`` in multiple threads is safe and lock
@ -1617,6 +1676,15 @@ class Booster(object):
missing : float missing : float
Value in the input data which needs to be present as a missing Value in the input data which needs to be present as a missing
value. value.
validate_features:
See :py:meth:`xgboost.Booster.predict` for details.
base_margin:
See :py:obj:`xgboost.DMatrix` for details.
strict_shape:
When set to True, output shape is invariant to whether classification is used.
For both value and margin prediction, the output shape is (n_samples,
n_groups), n_groups == 1 when multi-class is not used. Default to False, in
which case the output shape can be (n_samples, ) if multi-class is not used.
Returns Returns
------- -------
@ -1624,107 +1692,117 @@ class Booster(object):
The prediction result. When input data is on GPU, prediction The prediction result. When input data is on GPU, prediction
result is stored in a cupy array. result is stored in a cupy array.
''' """
def reshape_output(predt, rows):
'''Reshape for multi-output prediction.'''
if predt.size != rows and predt.size % rows == 0:
cols = int(predt.size / rows)
predt = predt.reshape(rows, cols)
return predt
return predt
length = c_bst_ulong()
preds = ctypes.POINTER(ctypes.c_float)() preds = ctypes.POINTER(ctypes.c_float)()
iteration_range = (ctypes.c_uint(iteration_range[0]),
ctypes.c_uint(iteration_range[1]))
# once caching is supported, we can pass id(data) as cache id. # once caching is supported, we can pass id(data) as cache id.
try: try:
import pandas as pd import pandas as pd
if isinstance(data, pd.DataFrame): if isinstance(data, pd.DataFrame):
data = data.values data = data.values
except ImportError: except ImportError:
pass pass
args = {
"type": 0,
"training": False,
"iteration_begin": iteration_range[0],
"iteration_end": iteration_range[1],
"missing": missing,
"strict_shape": strict_shape,
"cache_id": 0,
}
if predict_type == "margin":
args["type"] = 1
shape = ctypes.POINTER(c_bst_ulong)()
dims = c_bst_ulong()
if base_margin is not None:
proxy = _ProxyDMatrix()
proxy.set_info(base_margin=base_margin)
p_handle = proxy.handle
else:
proxy = None
p_handle = ctypes.c_void_p()
assert proxy is None or isinstance(proxy, _ProxyDMatrix)
if validate_features:
if len(data.shape) != 1 and self.num_features() != data.shape[1]:
raise ValueError(
f"Feature shape mismatch, expected: {self.num_features()}, "
f"got {data.shape[0]}"
)
if isinstance(data, np.ndarray): if isinstance(data, np.ndarray):
assert data.flags.c_contiguous from .data import _maybe_np_slice
arr = np.array(data.reshape(data.size), copy=False, data = _maybe_np_slice(data, data.dtype)
dtype=np.float32) _check_call(
_check_call(_LIB.XGBoosterPredictFromDense( _LIB.XGBoosterPredictFromDense(
self.handle, self.handle,
arr.ctypes.data_as(ctypes.POINTER(ctypes.c_float)), _array_interface(data),
c_bst_ulong(data.shape[0]), from_pystr_to_cstr(json.dumps(args)),
c_bst_ulong(data.shape[1]), p_handle,
ctypes.c_float(missing), ctypes.byref(shape),
iteration_range[0], ctypes.byref(dims),
iteration_range[1], ctypes.byref(preds),
c_str(predict_type), )
c_bst_ulong(0), )
ctypes.byref(length), return _prediction_output(shape, dims, preds, False)
ctypes.byref(preds)
))
preds = ctypes2numpy(preds, length.value, np.float32)
rows = data.shape[0]
return reshape_output(preds, rows)
if isinstance(data, scipy.sparse.csr_matrix): if isinstance(data, scipy.sparse.csr_matrix):
csr = data csr = data
_check_call(_LIB.XGBoosterPredictFromCSR( _check_call(
_LIB.XGBoosterPredictFromCSR(
self.handle, self.handle,
c_array(ctypes.c_size_t, csr.indptr), _array_interface(csr.indptr),
c_array(ctypes.c_uint, csr.indices), _array_interface(csr.indices),
c_array(ctypes.c_float, csr.data), _array_interface(csr.data),
ctypes.c_size_t(len(csr.indptr)),
ctypes.c_size_t(len(csr.data)),
ctypes.c_size_t(csr.shape[1]), ctypes.c_size_t(csr.shape[1]),
ctypes.c_float(missing), from_pystr_to_cstr(json.dumps(args)),
iteration_range[0], p_handle,
iteration_range[1], ctypes.byref(shape),
c_str(predict_type), ctypes.byref(dims),
c_bst_ulong(0), ctypes.byref(preds),
ctypes.byref(length), )
ctypes.byref(preds))) )
preds = ctypes2numpy(preds, length.value, np.float32) return _prediction_output(shape, dims, preds, False)
rows = data.shape[0] if lazy_isinstance(data, "cupy.core.core", "ndarray"):
return reshape_output(preds, rows) from .data import _transform_cupy_array
if lazy_isinstance(data, 'cupy.core.core', 'ndarray'): data = _transform_cupy_array(data)
assert data.flags.c_contiguous
interface = data.__cuda_array_interface__ interface = data.__cuda_array_interface__
if 'mask' in interface: if "mask" in interface:
interface['mask'] = interface['mask'].__cuda_array_interface__ interface["mask"] = interface["mask"].__cuda_array_interface__
interface_str = bytes(json.dumps(interface, indent=2), 'utf-8') interface_str = bytes(json.dumps(interface, indent=2), "utf-8")
_check_call(_LIB.XGBoosterPredictFromArrayInterface( _check_call(
_LIB.XGBoosterPredictFromArrayInterface(
self.handle, self.handle,
interface_str, interface_str,
ctypes.c_float(missing), from_pystr_to_cstr(json.dumps(args)),
iteration_range[0], p_handle,
iteration_range[1], ctypes.byref(shape),
c_str(predict_type), ctypes.byref(dims),
c_bst_ulong(0), ctypes.byref(preds),
ctypes.byref(length), )
ctypes.byref(preds))) )
mem = ctypes2cupy(preds, length, np.float32) return _prediction_output(shape, dims, preds, True)
rows = data.shape[0] if lazy_isinstance(data, "cudf.core.dataframe", "DataFrame"):
return reshape_output(mem, rows)
if lazy_isinstance(data, 'cudf.core.dataframe', 'DataFrame'):
from .data import _cudf_array_interfaces from .data import _cudf_array_interfaces
interfaces_str = _cudf_array_interfaces(data) interfaces_str = _cudf_array_interfaces(data)
_check_call(_LIB.XGBoosterPredictFromArrayInterfaceColumns( _check_call(
_LIB.XGBoosterPredictFromArrayInterfaceColumns(
self.handle, self.handle,
interfaces_str, interfaces_str,
ctypes.c_float(missing), from_pystr_to_cstr(json.dumps(args)),
iteration_range[0], p_handle,
iteration_range[1], ctypes.byref(shape),
c_str(predict_type), ctypes.byref(dims),
c_bst_ulong(0), ctypes.byref(preds),
ctypes.byref(length), )
ctypes.byref(preds))) )
mem = ctypes2cupy(preds, length, np.float32) return _prediction_output(shape, dims, preds, True)
rows = data.shape[0]
predt = reshape_output(mem, rows)
return predt
raise TypeError('Data type:' + str(type(data)) + raise TypeError(
' not supported by inplace prediction.') "Data type:" + str(type(data)) + " not supported by inplace prediction."
)
def save_model(self, fname): def save_model(self, fname):
"""Save the model to a file. """Save the model to a file.

View File

@ -187,8 +187,8 @@ class DaskDMatrix:
`DaskDMatrix` forces all lazy computation to be carried out. Wait for the input data `DaskDMatrix` forces all lazy computation to be carried out. Wait for the input data
explicitly if you want to see actual computation of constructing `DaskDMatrix`. explicitly if you want to see actual computation of constructing `DaskDMatrix`.
See doc string for DMatrix constructor for other parameters. DaskDMatrix accepts only See doc for :py:obj:`xgboost.DMatrix` constructor for other parameters. DaskDMatrix
dask collection. accepts only dask collection.
.. note:: .. note::
@ -575,7 +575,8 @@ class DaskDeviceQuantileDMatrix(DaskDMatrix):
memory usage by eliminating data copies. Internally all the partitions/chunks of data memory usage by eliminating data copies. Internally all the partitions/chunks of data
are merged by weighted GK sketching. So the number of partitions from dask may affect are merged by weighted GK sketching. So the number of partitions from dask may affect
training accuracy as GK generates bounded error for each merge. See doc string for training accuracy as GK generates bounded error for each merge. See doc string for
`DeviceQuantileDMatrix` and `DMatrix` for other parameters. :py:obj:`xgboost.DeviceQuantileDMatrix` and :py:obj:`xgboost.DMatrix` for other
parameters.
.. versionadded:: 1.2.0 .. versionadded:: 1.2.0

View File

@ -5,11 +5,12 @@ import ctypes
import json import json
import warnings import warnings
import os import os
from typing import Any
import numpy as np import numpy as np
from .core import c_array, _LIB, _check_call, c_str from .core import c_array, _LIB, _check_call, c_str
from .core import DataIter, DeviceQuantileDMatrix, DMatrix from .core import DataIter, _ProxyDMatrix, DMatrix
from .compat import lazy_isinstance from .compat import lazy_isinstance
c_bst_ulong = ctypes.c_uint64 # pylint: disable=invalid-name c_bst_ulong = ctypes.c_uint64 # pylint: disable=invalid-name
@ -113,7 +114,7 @@ def _maybe_np_slice(data, dtype):
return data return data
def _transform_np_array(data: np.ndarray): def _transform_np_array(data: np.ndarray) -> np.ndarray:
if not isinstance(data, np.ndarray) and hasattr(data, '__array__'): if not isinstance(data, np.ndarray) and hasattr(data, '__array__'):
data = np.array(data, copy=False) data = np.array(data, copy=False)
if len(data.shape) != 2: if len(data.shape) != 2:
@ -142,7 +143,7 @@ def _from_numpy_array(data, missing, nthread, feature_names, feature_types):
input layout and type if memory use is a concern. input layout and type if memory use is a concern.
""" """
flatten = _transform_np_array(data) flatten: np.ndarray = _transform_np_array(data)
handle = ctypes.c_void_p() handle = ctypes.c_void_p()
_check_call(_LIB.XGDMatrixCreateFromMat_omp( _check_call(_LIB.XGDMatrixCreateFromMat_omp(
flatten.ctypes.data_as(ctypes.POINTER(ctypes.c_float)), flatten.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
@ -783,54 +784,6 @@ class SingleBatchInternalIter(DataIter): # pylint: disable=R0902
self.it = 0 self.it = 0
def init_device_quantile_dmatrix(
        data, missing, max_bin, threads, feature_names, feature_types, **meta
):
    '''Create the native handle backing a DeviceQuantileDMatrix.

    ``data`` must be a cudf DataFrame/Series, a cupy array, a dlpack capsule or
    a data iterator; anything else raises ``TypeError``.  Non-iterator input is
    wrapped in a ``SingleBatchInternalIter`` together with ``meta``,
    ``feature_names`` and ``feature_types``.

    Returns a tuple ``(handle, feature_names, feature_types)`` where the names
    and types are read back from the freshly created matrix.
    '''
    supported = [_is_cudf_df(data), _is_cudf_ser(data), _is_cupy_array(data),
                 _is_dlpack(data), _is_iter(data)]
    if not any(supported):
        raise TypeError(str(type(data)) +
                        ' is not supported for DeviceQuantileDMatrix')

    if _is_dlpack(data):
        # cupy takes over the memory of a dlpack capsule, so the capsule can
        # only be converted once; do it here, before anything else touches it.
        data = _transform_dlpack(data)

    if _is_iter(data):
        batch_iter = data
    else:
        batch_iter = SingleBatchInternalIter(
            data, **meta, feature_names=feature_names,
            feature_types=feature_types)

    # Keep the ctypes callback objects bound to local names for the duration
    # of the native call.
    reset_proto = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
    next_proto = ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_void_p)
    reset_callback = reset_proto(batch_iter.reset_wrapper)
    next_callback = next_proto(batch_iter.next_wrapper)

    handle = ctypes.c_void_p()
    status = _LIB.XGDeviceQuantileDMatrixCreateFromCallback(
        None,
        batch_iter.proxy.handle,
        reset_callback,
        next_callback,
        ctypes.c_float(missing),
        ctypes.c_int(threads),
        ctypes.c_int(max_bin),
        ctypes.byref(handle)
    )
    # An exception recorded by the iterator takes priority over the native
    # return code, so it is re-raised before the status is checked.
    if batch_iter.exception:
        raise batch_iter.exception
    _check_call(status)

    # Wrap the handle temporarily just to read the feature metadata back.
    wrapper = DeviceQuantileDMatrix(handle)
    feature_names = wrapper.feature_names
    feature_types = wrapper.feature_types
    # NOTE(review): presumably detaches the handle so the temporary wrapper
    # does not release it on destruction -- confirm against DMatrix.__del__.
    wrapper.handle = None
    return handle, feature_names, feature_types
def _device_quantile_transform(data, feature_names, feature_types): def _device_quantile_transform(data, feature_names, feature_types):
if _is_cudf_df(data): if _is_cudf_df(data):
return _transform_cudf_df(data, feature_names, feature_types) return _transform_cudf_df(data, feature_names, feature_types)
@ -845,7 +798,7 @@ def _device_quantile_transform(data, feature_names, feature_types):
str(type(data))) str(type(data)))
def dispatch_device_quantile_dmatrix_set_data(proxy, data): def dispatch_device_quantile_dmatrix_set_data(proxy: _ProxyDMatrix, data: Any) -> None:
'''Dispatch for DeviceQuantileDMatrix.''' '''Dispatch for DeviceQuantileDMatrix.'''
if _is_cudf_df(data): if _is_cudf_df(data):
proxy._set_data_from_cuda_columnar(data) # pylint: disable=W0212 proxy._set_data_from_cuda_columnar(data) # pylint: disable=W0212

View File

@ -21,6 +21,7 @@
#include "xgboost/global_config.h" #include "xgboost/global_config.h"
#include "c_api_error.h" #include "c_api_error.h"
#include "c_api_utils.h"
#include "../common/io.h" #include "../common/io.h"
#include "../common/charconv.h" #include "../common/charconv.h"
#include "../data/adapter.h" #include "../data/adapter.h"
@ -617,89 +618,91 @@ XGB_DLL int XGBoosterPredict(BoosterHandle handle,
API_END(); API_END();
} }
template <typename T>
void InplacePredictImpl(std::shared_ptr<T> x, std::shared_ptr<DMatrix> p_m,
char const *c_json_config, Learner *learner,
size_t n_rows, size_t n_cols,
xgboost::bst_ulong const **out_shape,
xgboost::bst_ulong *out_dim, const float **out_result) {
auto config = Json::Load(StringView{c_json_config});
CHECK_EQ(get<Integer const>(config["cache_id"]), 0) << "Cache ID is not supported yet";
HostDeviceVector<float>* p_predt { nullptr };
auto type = PredictionType(get<Integer const>(config["type"]));
learner->InplacePredict(x, p_m, type, get<Number const>(config["missing"]),
&p_predt,
get<Integer const>(config["iteration_begin"]),
get<Integer const>(config["iteration_end"]));
CHECK(p_predt);
auto &shape = learner->GetThreadLocal().prediction_shape;
auto chunksize = n_rows == 0 ? 0 : p_predt->Size() / n_rows;
bool strict_shape = get<Boolean const>(config["strict_shape"]);
CalcPredictShape(strict_shape, type, n_rows, n_cols, chunksize, learner->Groups(),
learner->BoostedRounds(), &shape, out_dim);
*out_result = dmlc::BeginPtr(p_predt->HostVector());
*out_shape = dmlc::BeginPtr(shape);
}
// A hidden API as cache id is not being supported yet. // A hidden API as cache id is not being supported yet.
XGB_DLL int XGBoosterPredictFromDense(BoosterHandle handle, float *values, XGB_DLL int XGBoosterPredictFromDense(BoosterHandle handle,
xgboost::bst_ulong n_rows, char const *array_interface,
xgboost::bst_ulong n_cols, char const *c_json_config,
float missing, DMatrixHandle m,
unsigned iteration_begin, xgboost::bst_ulong const **out_shape,
unsigned iteration_end, xgboost::bst_ulong *out_dim,
char const* c_type,
xgboost::bst_ulong cache_id,
xgboost::bst_ulong *out_len,
const float **out_result) { const float **out_result) {
API_BEGIN(); API_BEGIN();
CHECK_HANDLE(); CHECK_HANDLE();
CHECK_EQ(cache_id, 0) << "Cache ID is not supported yet"; std::shared_ptr<xgboost::data::ArrayAdapter> x{
new xgboost::data::ArrayAdapter(StringView{array_interface})};
std::shared_ptr<DMatrix> p_m {nullptr};
if (m) {
p_m = *static_cast<std::shared_ptr<DMatrix> *>(m);
}
auto *learner = static_cast<xgboost::Learner *>(handle); auto *learner = static_cast<xgboost::Learner *>(handle);
InplacePredictImpl(x, p_m, c_json_config, learner, x->NumRows(),
std::shared_ptr<xgboost::data::DenseAdapter> x{ x->NumColumns(), out_shape, out_dim, out_result);
new xgboost::data::DenseAdapter(values, n_rows, n_cols)};
HostDeviceVector<float>* p_predt { nullptr };
std::string type { c_type };
learner->InplacePredict(x, type, missing, &p_predt, iteration_begin, iteration_end);
CHECK(p_predt);
*out_result = dmlc::BeginPtr(p_predt->HostVector());
*out_len = static_cast<xgboost::bst_ulong>(p_predt->Size());
API_END(); API_END();
} }
// A hidden API as cache id is not being supported yet. // A hidden API as cache id is not being supported yet.
XGB_DLL int XGBoosterPredictFromCSR(BoosterHandle handle, XGB_DLL int XGBoosterPredictFromCSR(BoosterHandle handle, char const *indptr,
const size_t* indptr, char const *indices, char const *data,
const unsigned* indices, xgboost::bst_ulong cols,
const bst_float* data, char const *c_json_config, DMatrixHandle m,
size_t nindptr, xgboost::bst_ulong const **out_shape,
size_t nelem, xgboost::bst_ulong *out_dim,
size_t num_col,
float missing,
unsigned iteration_begin,
unsigned iteration_end,
char const *c_type,
xgboost::bst_ulong cache_id,
xgboost::bst_ulong *out_len,
const float **out_result) { const float **out_result) {
API_BEGIN(); API_BEGIN();
CHECK_HANDLE(); CHECK_HANDLE();
CHECK_EQ(cache_id, 0) << "Cache ID is not supported yet"; std::shared_ptr<xgboost::data::CSRArrayAdapter> x{
new xgboost::data::CSRArrayAdapter{
StringView{indptr}, StringView{indices}, StringView{data}, cols}};
std::shared_ptr<DMatrix> p_m {nullptr};
if (m) {
p_m = *static_cast<std::shared_ptr<DMatrix> *>(m);
}
auto *learner = static_cast<xgboost::Learner *>(handle); auto *learner = static_cast<xgboost::Learner *>(handle);
InplacePredictImpl(x, p_m, c_json_config, learner, x->NumRows(),
std::shared_ptr<xgboost::data::CSRAdapter> x{ x->NumColumns(), out_shape, out_dim, out_result);
new xgboost::data::CSRAdapter(indptr, indices, data, nindptr - 1, nelem, num_col)};
HostDeviceVector<float>* p_predt { nullptr };
std::string type { c_type };
learner->InplacePredict(x, type, missing, &p_predt, iteration_begin, iteration_end);
CHECK(p_predt);
*out_result = dmlc::BeginPtr(p_predt->HostVector());
*out_len = static_cast<xgboost::bst_ulong>(p_predt->Size());
API_END(); API_END();
} }
#if !defined(XGBOOST_USE_CUDA) #if !defined(XGBOOST_USE_CUDA)
XGB_DLL int XGBoosterPredictFromArrayInterfaceColumns(BoosterHandle handle, XGB_DLL int XGBoosterPredictFromArrayInterface(
char const* c_json_strs, BoosterHandle handle, char const *c_json_strs, char const *c_json_config,
float missing, DMatrixHandle m, xgboost::bst_ulong const **out_shape, xgboost::bst_ulong *out_dim,
unsigned iteration_begin, const float **out_result) {
unsigned iteration_end,
char const* c_type,
xgboost::bst_ulong cache_id,
xgboost::bst_ulong *out_len,
float const** out_result) {
API_BEGIN(); API_BEGIN();
CHECK_HANDLE(); CHECK_HANDLE();
common::AssertGPUSupport(); common::AssertGPUSupport();
API_END(); API_END();
} }
XGB_DLL int XGBoosterPredictFromArrayInterface(BoosterHandle handle,
char const* c_json_strs, XGB_DLL int XGBoosterPredictFromArrayInterfaceColumns(
float missing, BoosterHandle handle, char const *c_json_strs, char const *c_json_config,
unsigned iteration_begin, DMatrixHandle m, xgboost::bst_ulong const **out_shape, xgboost::bst_ulong *out_dim,
unsigned iteration_end,
char const* c_type,
xgboost::bst_ulong cache_id,
xgboost::bst_ulong *out_len,
const float **out_result) { const float **out_result) {
API_BEGIN(); API_BEGIN();
CHECK_HANDLE(); CHECK_HANDLE();

View File

@ -1,8 +1,9 @@
// Copyright (c) 2019-2020 by Contributors // Copyright (c) 2019-2021 by Contributors
#include "xgboost/data.h" #include "xgboost/data.h"
#include "xgboost/c_api.h" #include "xgboost/c_api.h"
#include "xgboost/learner.h" #include "xgboost/learner.h"
#include "c_api_error.h" #include "c_api_error.h"
#include "c_api_utils.h"
#include "../data/device_adapter.cuh" #include "../data/device_adapter.cuh"
using namespace xgboost; // NOLINT using namespace xgboost; // NOLINT
@ -30,59 +31,63 @@ XGB_DLL int XGDMatrixCreateFromArrayInterface(char const* c_json_strs,
API_END(); API_END();
} }
// A hidden API as cache id is not being supported yet. template <typename T>
XGB_DLL int XGBoosterPredictFromArrayInterfaceColumns(BoosterHandle handle, int InplacePreidctCuda(BoosterHandle handle, char const *c_json_strs,
char const* c_json_strs, char const *c_json_config,
float missing, std::shared_ptr<DMatrix> p_m,
unsigned iteration_begin, xgboost::bst_ulong const **out_shape,
unsigned iteration_end, xgboost::bst_ulong *out_dim, const float **out_result) {
char const* c_type,
xgboost::bst_ulong cache_id,
xgboost::bst_ulong *out_len,
float const** out_result) {
API_BEGIN(); API_BEGIN();
CHECK_HANDLE(); CHECK_HANDLE();
CHECK_EQ(cache_id, 0) << "Cache ID is not supported yet"; auto config = Json::Load(StringView{c_json_config});
CHECK_EQ(get<Integer const>(config["cache_id"]), 0)
<< "Cache ID is not supported yet";
auto *learner = static_cast<Learner *>(handle); auto *learner = static_cast<Learner *>(handle);
std::string json_str{c_json_strs}; std::string json_str{c_json_strs};
auto x = std::make_shared<data::CudfAdapter>(json_str); auto x = std::make_shared<T>(json_str);
HostDeviceVector<float> *p_predt{nullptr}; HostDeviceVector<float> *p_predt{nullptr};
std::string type { c_type }; auto type = PredictionType(get<Integer const>(config["type"]));
learner->InplacePredict(x, type, missing, &p_predt, iteration_begin, iteration_end); learner->InplacePredict(x, p_m, type, get<Number const>(config["missing"]),
&p_predt,
get<Integer const>(config["iteration_begin"]),
get<Integer const>(config["iteration_end"]));
CHECK(p_predt); CHECK(p_predt);
CHECK(p_predt->DeviceCanRead()); CHECK(p_predt->DeviceCanRead() && !p_predt->HostCanRead());
auto &shape = learner->GetThreadLocal().prediction_shape;
auto chunksize = x->NumRows() == 0 ? 0 : p_predt->Size() / x->NumRows();
bool strict_shape = get<Boolean const>(config["strict_shape"]);
CalcPredictShape(strict_shape, type, x->NumRows(), x->NumColumns(), chunksize,
learner->Groups(), learner->BoostedRounds(), &shape,
out_dim);
*out_shape = dmlc::BeginPtr(shape);
*out_result = p_predt->ConstDevicePointer(); *out_result = p_predt->ConstDevicePointer();
*out_len = static_cast<xgboost::bst_ulong>(p_predt->Size());
API_END(); API_END();
} }
// A hidden API as cache id is not being supported yet. // A hidden API as cache id is not being supported yet.
XGB_DLL int XGBoosterPredictFromArrayInterface(BoosterHandle handle, XGB_DLL int XGBoosterPredictFromArrayInterfaceColumns(
char const* c_json_strs, BoosterHandle handle, char const *c_json_strs, char const *c_json_config,
float missing, DMatrixHandle m, xgboost::bst_ulong const **out_shape,
unsigned iteration_begin, xgboost::bst_ulong *out_dim, const float **out_result) {
unsigned iteration_end, std::shared_ptr<DMatrix> p_m {nullptr};
char const* c_type, if (m) {
xgboost::bst_ulong cache_id, p_m = *static_cast<std::shared_ptr<DMatrix> *>(m);
xgboost::bst_ulong *out_len, }
float const** out_result) { return InplacePreidctCuda<data::CudfAdapter>(
API_BEGIN(); handle, c_json_strs, c_json_config, p_m, out_shape, out_dim, out_result);
CHECK_HANDLE(); }
CHECK_EQ(cache_id, 0) << "Cache ID is not supported yet";
auto *learner = static_cast<Learner*>(handle); // A hidden API as cache id is not being supported yet.
XGB_DLL int XGBoosterPredictFromArrayInterface(
std::string json_str{c_json_strs}; BoosterHandle handle, char const *c_json_strs, char const *c_json_config,
auto x = std::make_shared<data::CupyAdapter>(json_str); DMatrixHandle m, xgboost::bst_ulong const **out_shape,
HostDeviceVector<float>* p_predt { nullptr }; xgboost::bst_ulong *out_dim, const float **out_result) {
std::string type { c_type }; std::shared_ptr<DMatrix> p_m {nullptr};
learner->InplacePredict(x, type, missing, &p_predt, iteration_begin, iteration_end); if (m) {
CHECK(p_predt); p_m = *static_cast<std::shared_ptr<DMatrix> *>(m);
CHECK(p_predt->DeviceCanRead()); }
return InplacePreidctCuda<data::CupyAdapter>(
*out_result = p_predt->ConstDevicePointer(); handle, c_json_strs, c_json_config, p_m, out_shape, out_dim, out_result);
*out_len = static_cast<xgboost::bst_ulong>(p_predt->Size());
API_END();
} }

114
src/c_api/c_api_utils.h Normal file
View File

@ -0,0 +1,114 @@
/*!
* Copyright (c) 2021 by XGBoost Contributors
*/
#ifndef XGBOOST_C_API_C_API_UTILS_H_
#define XGBOOST_C_API_C_API_UTILS_H_
#include <algorithm>
#include <functional>
#include <numeric>
#include <vector>
#include "xgboost/logging.h"
#include "xgboost/learner.h"
namespace xgboost {
/* \brief Determine the output shape of prediction.
*
* \param strict_shape Whether should we reshape the output with consideration of groups
* and forest.
* \param type Prediction type
* \param rows Input samples
* \param cols Input features
* \param chunksize Total elements of output / rows
* \param groups Number of output groups from Learner
* \param rounds end_iteration - beg_iteration
* \param out_shape Output shape
* \param out_dim Output dimension
*/
inline void CalcPredictShape(bool strict_shape, PredictionType type, size_t rows, size_t cols,
size_t chunksize, size_t groups, size_t rounds,
std::vector<bst_ulong> *out_shape,
xgboost::bst_ulong *out_dim) {
auto &shape = *out_shape;
if ((type == PredictionType::kMargin || type == PredictionType::kValue) &&
rows != 0) {
CHECK_EQ(chunksize, groups);
}
switch (type) {
case PredictionType::kValue:
case PredictionType::kMargin: {
if (chunksize == 1 && !strict_shape) {
*out_dim = 1;
shape.resize(*out_dim);
shape.front() = rows;
} else {
*out_dim = 2;
shape.resize(*out_dim);
shape.front() = rows;
shape.back() = groups;
}
break;
}
case PredictionType::kApproxContribution:
case PredictionType::kContribution: {
auto groups = chunksize / (cols + 1);
if (groups == 1 && !strict_shape) {
*out_dim = 2;
shape.resize(*out_dim);
shape.front() = rows;
shape.back() = cols + 1;
} else {
*out_dim = 3;
shape.resize(*out_dim);
shape[0] = rows;
shape[1] = groups;
shape[2] = cols + 1;
}
break;
}
case PredictionType::kInteraction: {
if (groups == 1 && !strict_shape) {
*out_dim = 3;
shape.resize(*out_dim);
shape[0] = rows;
shape[1] = cols + 1;
shape[2] = cols + 1;
} else {
*out_dim = 4;
shape.resize(*out_dim);
shape[0] = rows;
shape[1] = groups;
shape[2] = cols + 1;
shape[3] = cols + 1;
}
break;
}
case PredictionType::kLeaf: {
if (strict_shape) {
shape.resize(4);
shape[0] = rows;
shape[1] = rounds;
shape[2] = groups;
auto forest = chunksize / (shape[1] * shape[2]);
forest = std::max(static_cast<decltype(forest)>(1), forest);
shape[3] = forest;
*out_dim = shape.size();
} else {
*out_dim = 2;
shape.resize(*out_dim);
shape.front() = rows;
shape.back() = chunksize;
}
break;
}
default: {
LOG(FATAL) << "Unknown prediction type:" << static_cast<int>(type);
}
}
CHECK_EQ(
std::accumulate(shape.cbegin(), shape.cend(), 1, std::multiplies<>{}),
chunksize * rows);
}
} // namespace xgboost
#endif // XGBOOST_C_API_C_API_UTILS_H_

View File

@ -1,5 +1,5 @@
/*! /*!
* Copyright (c) 2019~2020 by Contributors * Copyright (c) 2019~2021 by Contributors
* \file adapter.h * \file adapter.h
*/ */
#ifndef XGBOOST_DATA_ADAPTER_H_ #ifndef XGBOOST_DATA_ADAPTER_H_
@ -228,6 +228,128 @@ class DenseAdapter : public detail::SingleBatchDataIter<DenseAdapterBatch> {
size_t num_columns_; size_t num_columns_;
}; };
class ArrayAdapterBatch : public detail::NoMetaInfo {
ArrayInterface array_interface_;
class Line {
ArrayInterface array_interface_;
size_t ridx_;
public:
Line(ArrayInterface array_interface, size_t ridx)
: array_interface_{std::move(array_interface)}, ridx_{ridx} {}
size_t Size() const { return array_interface_.num_cols; }
COOTuple GetElement(size_t idx) const {
return {ridx_, idx, array_interface_.GetElement(idx)};
}
};
public:
ArrayAdapterBatch() = default;
Line const GetLine(size_t idx) const {
auto line = array_interface_.SliceRow(idx);
return Line{line, idx};
}
explicit ArrayAdapterBatch(ArrayInterface array_interface)
: array_interface_{std::move(array_interface)} {}
};
/**
 * Adapter for a dense array on host; in Python that's `numpy.ndarray`.  It plays the
 * same role as `DenseAdapter` but is built from an `__array_interface__` JSON string
 * rather than raw pointers, which lets it handle various data types without making
 * a copy.
 */
class ArrayAdapter : public detail::SingleBatchDataIter<ArrayAdapterBatch> {
 public:
  /** Parse the JSON document and build the single batch from it. */
  explicit ArrayAdapter(StringView array_interface) {
    auto parsed = Json::Load(array_interface);
    array_interface_ = ArrayInterface(get<Object const>(parsed));
    batch_ = ArrayAdapterBatch{array_interface_};
  }

  ArrayAdapterBatch const& Value() const override {
    return batch_;
  }

  size_t NumRows() const {
    return array_interface_.num_rows;
  }
  size_t NumColumns() const {
    return array_interface_.num_cols;
  }

 private:
  ArrayAdapterBatch batch_;
  ArrayInterface array_interface_;
};
class CSRArrayAdapterBatch : public detail::NoMetaInfo {
ArrayInterface indptr_;
ArrayInterface indices_;
ArrayInterface values_;
class Line {
ArrayInterface indices_;
ArrayInterface values_;
size_t ridx_;
public:
Line(ArrayInterface indices, ArrayInterface values, size_t ridx)
: indices_{std::move(indices)}, values_{std::move(values)}, ridx_{ridx} {}
COOTuple GetElement(size_t idx) const {
return {ridx_, indices_.GetElement<size_t>(idx), values_.GetElement(idx)};
}
size_t Size() const {
return values_.num_rows * values_.num_cols;
}
};
public:
CSRArrayAdapterBatch() = default;
CSRArrayAdapterBatch(ArrayInterface indptr, ArrayInterface indices,
ArrayInterface values)
: indptr_{std::move(indptr)}, indices_{std::move(indices)},
values_{std::move(values)} {}
Line const GetLine(size_t idx) const {
auto begin_offset = indptr_.GetElement<size_t>(idx);
auto end_offset = indptr_.GetElement<size_t>(idx + 1);
auto indices = indices_.SliceOffset(begin_offset);
auto values = values_.SliceOffset(begin_offset);
values.num_cols = end_offset - begin_offset;
values.num_rows = 1;
indices.num_cols = values.num_cols;
indices.num_rows = values.num_rows;
return Line{indices, values, idx};
}
};
/**
 * Adapter for a CSR array on host; in Python that's `scipy.sparse.csr_matrix`.  It
 * plays the same role as `CSRAdapter` but is built from `__array_interface__` JSON
 * strings rather than raw pointers, which lets it handle various data types without
 * making a copy.
 */
class CSRArrayAdapter : public detail::SingleBatchDataIter<CSRArrayAdapterBatch> {
 public:
  /** Each of `indptr`, `indices` and `values` is parsed into an `ArrayInterface`. */
  CSRArrayAdapter(StringView indptr, StringView indices, StringView values,
                  size_t num_cols)
      : indptr_{indptr}, indices_{indices}, values_{values}, num_cols_{num_cols} {
    batch_ = CSRArrayAdapterBatch{indptr_, indices_, values_};
  }

  CSRArrayAdapterBatch const& Value() const override { return batch_; }

  /** Length of `indptr` minus one, or zero when `indptr` is empty. */
  size_t NumRows() const {
    size_t n_offsets = indptr_.num_cols * indptr_.num_rows;
    if (n_offsets == 0) {
      return 0;
    }
    return n_offsets - 1;
  }
  size_t NumColumns() const {
    return num_cols_;
  }

 private:
  CSRArrayAdapterBatch batch_;
  ArrayInterface indptr_;
  ArrayInterface indices_;
  ArrayInterface values_;
  size_t num_cols_;
};
class CSCAdapterBatch : public detail::NoMetaInfo { class CSCAdapterBatch : public detail::NoMetaInfo {
public: public:
CSCAdapterBatch(const size_t* col_ptr, const unsigned* row_idx, CSCAdapterBatch(const size_t* col_ptr, const unsigned* row_idx,

View File

@ -1,5 +1,5 @@
/*! /*!
* Copyright 2019 by Contributors * Copyright 2019-2021 by Contributors
* \file array_interface.h * \file array_interface.h
* \brief View of __array_interface__ * \brief View of __array_interface__
*/ */
@ -87,7 +87,7 @@ struct ArrayInterfaceErrors {
} }
} }
static std::string UnSupportedType(const char (&typestr)[3]) { static std::string UnSupportedType(StringView typestr) {
return TypeStr(typestr[1]) + " is not supported."; return TypeStr(typestr[1]) + " is not supported.";
} }
}; };
@ -210,6 +210,7 @@ class ArrayInterfaceHandler {
static_cast<bst_feature_t>(get<Integer const>(j_shape.at(1)))}; static_cast<bst_feature_t>(get<Integer const>(j_shape.at(1)))};
} }
} }
template <typename T> template <typename T>
static common::Span<T> ExtractData(std::map<std::string, Json> const& column) { static common::Span<T> ExtractData(std::map<std::string, Json> const& column) {
Validate(column); Validate(column);
@ -257,16 +258,24 @@ class ArrayInterface {
} }
auto typestr = get<String const>(column.at("typestr")); auto typestr = get<String const>(column.at("typestr"));
type[0] = typestr.at(0); this->AssignType(StringView{typestr});
type[1] = typestr.at(1);
type[2] = typestr.at(2);
this->CheckType();
} }
public:
enum Type : std::int8_t { kF4, kF8, kI1, kI2, kI4, kI8, kU1, kU2, kU4, kU8 };
public: public:
ArrayInterface() = default; ArrayInterface() = default;
explicit ArrayInterface(std::string const& str, bool allow_mask = true) { explicit ArrayInterface(std::string const &str, bool allow_mask = true)
auto jinterface = Json::Load({str.c_str(), str.size()}); : ArrayInterface{StringView{str.c_str(), str.size()}, allow_mask} {}
explicit ArrayInterface(std::map<std::string, Json> const &column,
bool allow_mask = true) {
this->Initialize(column, allow_mask);
}
explicit ArrayInterface(StringView str, bool allow_mask = true) {
auto jinterface = Json::Load(str);
if (IsA<Object>(jinterface)) { if (IsA<Object>(jinterface)) {
this->Initialize(get<Object const>(jinterface), allow_mask); this->Initialize(get<Object const>(jinterface), allow_mask);
return; return;
@ -279,71 +288,114 @@ class ArrayInterface {
} }
} }
explicit ArrayInterface(std::map<std::string, Json> const &column, void AssignType(StringView typestr) {
bool allow_mask = true) { if (typestr[1] == 'f' && typestr[2] == '4') {
this->Initialize(column, allow_mask); type = kF4;
} } else if (typestr[1] == 'f' && typestr[2] == '8') {
type = kF8;
void CheckType() const { } else if (typestr[1] == 'i' && typestr[2] == '1') {
if (type[1] == 'f' && type[2] == '4') { type = kI1;
return; } else if (typestr[1] == 'i' && typestr[2] == '2') {
} else if (type[1] == 'f' && type[2] == '8') { type = kI2;
return; } else if (typestr[1] == 'i' && typestr[2] == '4') {
} else if (type[1] == 'i' && type[2] == '1') { type = kI4;
return; } else if (typestr[1] == 'i' && typestr[2] == '8') {
} else if (type[1] == 'i' && type[2] == '2') { type = kI8;
return; } else if (typestr[1] == 'u' && typestr[2] == '1') {
} else if (type[1] == 'i' && type[2] == '4') { type = kU1;
return; } else if (typestr[1] == 'u' && typestr[2] == '2') {
} else if (type[1] == 'i' && type[2] == '8') { type = kU2;
return; } else if (typestr[1] == 'u' && typestr[2] == '4') {
} else if (type[1] == 'u' && type[2] == '1') { type = kU4;
return; } else if (typestr[1] == 'u' && typestr[2] == '8') {
} else if (type[1] == 'u' && type[2] == '2') { type = kU8;
return;
} else if (type[1] == 'u' && type[2] == '4') {
return;
} else if (type[1] == 'u' && type[2] == '8') {
return;
} else { } else {
LOG(FATAL) << ArrayInterfaceErrors::UnSupportedType(type); LOG(FATAL) << ArrayInterfaceErrors::UnSupportedType(typestr);
return; return;
} }
} }
XGBOOST_DEVICE float GetElement(size_t idx) const { XGBOOST_DEVICE ArrayInterface SliceOffset(size_t offset) const {
void* p_values;
switch (type) {
case kF4:
p_values = reinterpret_cast<float *>(data) + offset;
break;
case kF8:
p_values = reinterpret_cast<double *>(data) + offset;
break;
case kI1:
p_values = reinterpret_cast<int8_t *>(data) + offset;
break;
case kI2:
p_values = reinterpret_cast<int16_t *>(data) + offset;
break;
case kI4:
p_values = reinterpret_cast<int32_t *>(data) + offset;
break;
case kI8:
p_values = reinterpret_cast<int64_t *>(data) + offset;
break;
case kU1:
p_values = reinterpret_cast<uint8_t *>(data) + offset;
break;
case kU2:
p_values = reinterpret_cast<uint16_t *>(data) + offset;
break;
case kU4:
p_values = reinterpret_cast<uint32_t *>(data) + offset;
break;
case kU8:
p_values = reinterpret_cast<uint64_t *>(data) + offset;
break;
}
ArrayInterface ret = *this;
ret.data = p_values;
return ret;
}
/*!
 * \brief Copy of this interface whose data pointer is advanced by `idx * num_cols`
 *        elements and whose `num_rows` is set to 1.
 */
XGBOOST_DEVICE ArrayInterface SliceRow(size_t idx) const {
  auto row = this->SliceOffset(idx * num_cols);
  row.num_rows = 1;
  return row;
}
template <typename T = float>
XGBOOST_DEVICE T GetElement(size_t idx) const {
SPAN_CHECK(idx < num_cols * num_rows); SPAN_CHECK(idx < num_cols * num_rows);
if (type[1] == 'f' && type[2] == '4') { switch (type) {
case kF4:
return reinterpret_cast<float*>(data)[idx]; return reinterpret_cast<float*>(data)[idx];
} else if (type[1] == 'f' && type[2] == '8') { case kF8:
return reinterpret_cast<double*>(data)[idx]; return reinterpret_cast<double*>(data)[idx];
} else if (type[1] == 'i' && type[2] == '1') { case kI1:
return reinterpret_cast<int8_t*>(data)[idx]; return reinterpret_cast<int8_t*>(data)[idx];
} else if (type[1] == 'i' && type[2] == '2') { case kI2:
return reinterpret_cast<int16_t*>(data)[idx]; return reinterpret_cast<int16_t*>(data)[idx];
} else if (type[1] == 'i' && type[2] == '4') { case kI4:
return reinterpret_cast<int32_t*>(data)[idx]; return reinterpret_cast<int32_t*>(data)[idx];
} else if (type[1] == 'i' && type[2] == '8') { case kI8:
return reinterpret_cast<int64_t*>(data)[idx]; return reinterpret_cast<int64_t*>(data)[idx];
} else if (type[1] == 'u' && type[2] == '1') { case kU1:
return reinterpret_cast<uint8_t*>(data)[idx]; return reinterpret_cast<uint8_t*>(data)[idx];
} else if (type[1] == 'u' && type[2] == '2') { case kU2:
return reinterpret_cast<uint16_t*>(data)[idx]; return reinterpret_cast<uint16_t*>(data)[idx];
} else if (type[1] == 'u' && type[2] == '4') { case kU4:
return reinterpret_cast<uint32_t*>(data)[idx]; return reinterpret_cast<uint32_t*>(data)[idx];
} else if (type[1] == 'u' && type[2] == '8') { case kU8:
return reinterpret_cast<uint64_t*>(data)[idx]; return reinterpret_cast<uint64_t*>(data)[idx];
} else {
SPAN_CHECK(false);
return 0;
} }
SPAN_CHECK(false);
return reinterpret_cast<float*>(data)[idx];
} }
RBitField8 valid; RBitField8 valid;
bst_row_t num_rows; bst_row_t num_rows;
bst_feature_t num_cols; bst_feature_t num_cols;
void* data; void* data;
char type[3];
Type type;
}; };
} // namespace xgboost } // namespace xgboost

View File

@ -1,5 +1,5 @@
/*! /*!
* Copyright 2019 by XGBoost Contributors * Copyright 2019-2021 by XGBoost Contributors
* *
* \file data.cu * \file data.cu
* \brief Handles setting metainfo from array interface. * \brief Handles setting metainfo from array interface.
@ -45,15 +45,15 @@ auto SetDeviceToPtr(void *ptr) {
} // anonymous namespace } // anonymous namespace
void CopyGroupInfoImpl(ArrayInterface column, std::vector<bst_group_t>* out) { void CopyGroupInfoImpl(ArrayInterface column, std::vector<bst_group_t>* out) {
CHECK(column.type[1] == 'i' || column.type[1] == 'u') CHECK(column.type != ArrayInterface::kF4 && column.type != ArrayInterface::kF8)
<< "Expected integer metainfo"; << "Expected integer for group info.";
auto ptr_device = SetDeviceToPtr(column.data); auto ptr_device = SetDeviceToPtr(column.data);
dh::TemporaryArray<bst_group_t> temp(column.num_rows); dh::TemporaryArray<bst_group_t> temp(column.num_rows);
auto d_tmp = temp.data(); auto d_tmp = temp.data();
dh::LaunchN(ptr_device, column.num_rows, [=] __device__(size_t idx) { dh::LaunchN(ptr_device, column.num_rows, [=] __device__(size_t idx) {
d_tmp[idx] = column.GetElement(idx); d_tmp[idx] = column.GetElement<size_t>(idx);
}); });
auto length = column.num_rows; auto length = column.num_rows;
out->resize(length + 1); out->resize(length + 1);
@ -103,15 +103,15 @@ void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
auto it = dh::MakeTransformIterator<uint32_t>( auto it = dh::MakeTransformIterator<uint32_t>(
thrust::make_counting_iterator(0ul), thrust::make_counting_iterator(0ul),
[array_interface] __device__(size_t i) { [array_interface] __device__(size_t i) {
return static_cast<uint32_t>(array_interface.GetElement(i)); return array_interface.GetElement<uint32_t>(i);
}); });
dh::caching_device_vector<bool> flag(1); dh::caching_device_vector<bool> flag(1);
auto d_flag = dh::ToSpan(flag); auto d_flag = dh::ToSpan(flag);
auto d = SetDeviceToPtr(array_interface.data); auto d = SetDeviceToPtr(array_interface.data);
dh::LaunchN(d, 1, [=] __device__(size_t) { d_flag[0] = true; }); dh::LaunchN(d, 1, [=] __device__(size_t) { d_flag[0] = true; });
dh::LaunchN(d, array_interface.num_rows - 1, [=] __device__(size_t i) { dh::LaunchN(d, array_interface.num_rows - 1, [=] __device__(size_t i) {
if (static_cast<uint32_t>(array_interface.GetElement(i)) > if (array_interface.GetElement<uint32_t>(i) >
static_cast<uint32_t>(array_interface.GetElement(i + 1))) { array_interface.GetElement<uint32_t>(i + 1)) {
d_flag[0] = false; d_flag[0] = false;
} }
}); });

View File

@ -1,5 +1,5 @@
/*! /*!
* Copyright 2014-2020 by Contributors * Copyright 2014-2021 by Contributors
* \file gbtree.cc * \file gbtree.cc
* \brief gradient boosted tree implementation. * \brief gradient boosted tree implementation.
* \author Tianqi Chen * \author Tianqi Chen
@ -265,15 +265,34 @@ class GBTree : public GradientBooster {
bool training, bool training,
unsigned ntree_limit) override; unsigned ntree_limit) override;
void InplacePredict(dmlc::any const &x, float missing, void InplacePredict(dmlc::any const &x, std::shared_ptr<DMatrix> p_m,
PredictionCacheEntry *out_preds, float missing, PredictionCacheEntry *out_preds,
uint32_t layer_begin, uint32_t layer_begin, unsigned layer_end) const override {
unsigned layer_end) const override {
CHECK(configured_); CHECK(configured_);
uint32_t tree_begin, tree_end; uint32_t tree_begin, tree_end;
std::tie(tree_begin, tree_end) = detail::LayerToTree(model_, tparam_, layer_begin, layer_end); std::tie(tree_begin, tree_end) =
this->GetPredictor()->InplacePredict(x, model_, missing, out_preds, detail::LayerToTree(model_, tparam_, layer_begin, layer_end);
tree_begin, tree_end); std::vector<Predictor const *> predictors{
cpu_predictor_.get(),
#if defined(XGBOOST_USE_CUDA)
gpu_predictor_.get()
#endif // defined(XGBOOST_USE_CUDA)
};
StringView msg{"Unsupported data type for inplace predict."};
if (tparam_.predictor == PredictorType::kAuto) {
// Try both predictor implementations
for (auto const &p : predictors) {
if (p && p->InplacePredict(x, p_m, model_, missing, out_preds,
tree_begin, tree_end)) {
return;
}
}
LOG(FATAL) << msg;
} else {
bool success = this->GetPredictor()->InplacePredict(
x, p_m, model_, missing, out_preds, tree_begin, tree_end);
CHECK(success) << msg;
}
} }
void PredictInstance(const SparsePage::Inst& inst, void PredictInstance(const SparsePage::Inst& inst,

View File

@ -1,5 +1,5 @@
/*! /*!
* Copyright 2014-2020 by Contributors * Copyright 2014-2021 by Contributors
* \file learner.cc * \file learner.cc
* \brief Implementation of learning algorithm. * \brief Implementation of learning algorithm.
* \author Tianqi Chen * \author Tianqi Chen
@ -1110,23 +1110,30 @@ class LearnerImpl : public LearnerIO {
CHECK(!this->need_configuration_); CHECK(!this->need_configuration_);
return this->gbm_->BoostedRounds(); return this->gbm_->BoostedRounds();
} }
uint32_t Groups() const override {
CHECK(!this->need_configuration_);
return this->learner_model_param_.num_output_group;
}
XGBAPIThreadLocalEntry& GetThreadLocal() const override { XGBAPIThreadLocalEntry& GetThreadLocal() const override {
return (*LearnerAPIThreadLocalStore::Get())[this]; return (*LearnerAPIThreadLocalStore::Get())[this];
} }
void InplacePredict(dmlc::any const &x, std::string const &type, void InplacePredict(dmlc::any const &x, std::shared_ptr<DMatrix> p_m,
float missing, HostDeviceVector<bst_float> **out_preds, PredictionType type, float missing,
uint32_t layer_begin, uint32_t layer_end) override { HostDeviceVector<bst_float> **out_preds,
uint32_t iteration_begin,
uint32_t iteration_end) override {
this->Configure(); this->Configure();
auto& out_predictions = this->GetThreadLocal().prediction_entry; auto& out_predictions = this->GetThreadLocal().prediction_entry;
this->gbm_->InplacePredict(x, missing, &out_predictions, layer_begin, this->gbm_->InplacePredict(x, p_m, missing, &out_predictions,
layer_end); iteration_begin, iteration_end);
if (type == "value") { if (type == PredictionType::kValue) {
obj_->PredTransform(&out_predictions.predictions); obj_->PredTransform(&out_predictions.predictions);
} else if (type == "margin") { } else if (type == PredictionType::kMargin) {
// do nothing
} else { } else {
LOG(FATAL) << "Unsupported prediction type:" << type; LOG(FATAL) << "Unsupported prediction type:" << static_cast<int>(type);
} }
*out_preds = &out_predictions.predictions; *out_preds = &out_predictions.predictions;
} }

View File

@ -1,5 +1,5 @@
/*! /*!
* Copyright by Contributors 2017-2020 * Copyright by Contributors 2017-2021
*/ */
#include <dmlc/omp.h> #include <dmlc/omp.h>
#include <dmlc/any.h> #include <dmlc/any.h>
@ -287,7 +287,7 @@ class CPUPredictor : public Predictor {
} }
template <typename Adapter> template <typename Adapter>
void DispatchedInplacePredict(dmlc::any const &x, void DispatchedInplacePredict(dmlc::any const &x, std::shared_ptr<DMatrix> p_m,
const gbm::GBTreeModel &model, float missing, const gbm::GBTreeModel &model, float missing,
PredictionCacheEntry *out_preds, PredictionCacheEntry *out_preds,
uint32_t tree_begin, uint32_t tree_end) const { uint32_t tree_begin, uint32_t tree_end) const {
@ -295,33 +295,44 @@ class CPUPredictor : public Predictor {
auto m = dmlc::get<std::shared_ptr<Adapter>>(x); auto m = dmlc::get<std::shared_ptr<Adapter>>(x);
CHECK_EQ(m->NumColumns(), model.learner_model_param->num_feature) CHECK_EQ(m->NumColumns(), model.learner_model_param->num_feature)
<< "Number of columns in data must equal to trained model."; << "Number of columns in data must equal to trained model.";
if (p_m) {
p_m->Info().num_row_ = m->NumRows();
this->InitOutPredictions(p_m->Info(), &(out_preds->predictions), model);
} else {
MetaInfo info; MetaInfo info;
info.num_col_ = m->NumColumns();
info.num_row_ = m->NumRows(); info.num_row_ = m->NumRows();
this->InitOutPredictions(info, &(out_preds->predictions), model); this->InitOutPredictions(info, &(out_preds->predictions), model);
std::vector<Entry> workspace(info.num_col_ * 8 * threads); }
std::vector<Entry> workspace(m->NumColumns() * 8 * threads);
auto &predictions = out_preds->predictions.HostVector(); auto &predictions = out_preds->predictions.HostVector();
std::vector<RegTree::FVec> thread_temp; std::vector<RegTree::FVec> thread_temp;
InitThreadTemp(threads*kBlockOfRowsSize, model.learner_model_param->num_feature, InitThreadTemp(threads * kBlockOfRowsSize,
&thread_temp); model.learner_model_param->num_feature, &thread_temp);
PredictBatchByBlockOfRowsKernel<AdapterView<Adapter>, PredictBatchByBlockOfRowsKernel<AdapterView<Adapter>, kBlockOfRowsSize>(
kBlockOfRowsSize>(AdapterView<Adapter>( AdapterView<Adapter>(m.get(), missing, common::Span<Entry>{workspace}),
m.get(), missing, common::Span<Entry>{workspace}),
&predictions, model, tree_begin, tree_end, &thread_temp); &predictions, model, tree_begin, tree_end, &thread_temp);
} }
void InplacePredict(dmlc::any const &x, const gbm::GBTreeModel &model, bool InplacePredict(dmlc::any const &x, std::shared_ptr<DMatrix> p_m,
float missing, PredictionCacheEntry *out_preds, const gbm::GBTreeModel &model, float missing,
uint32_t tree_begin, unsigned tree_end) const override { PredictionCacheEntry *out_preds, uint32_t tree_begin,
unsigned tree_end) const override {
if (x.type() == typeid(std::shared_ptr<data::DenseAdapter>)) { if (x.type() == typeid(std::shared_ptr<data::DenseAdapter>)) {
this->DispatchedInplacePredict<data::DenseAdapter>( this->DispatchedInplacePredict<data::DenseAdapter>(
x, model, missing, out_preds, tree_begin, tree_end); x, p_m, model, missing, out_preds, tree_begin, tree_end);
} else if (x.type() == typeid(std::shared_ptr<data::CSRAdapter>)) { } else if (x.type() == typeid(std::shared_ptr<data::CSRAdapter>)) {
this->DispatchedInplacePredict<data::CSRAdapter>( this->DispatchedInplacePredict<data::CSRAdapter>(
x, model, missing, out_preds, tree_begin, tree_end); x, p_m, model, missing, out_preds, tree_begin, tree_end);
} else if (x.type() == typeid(std::shared_ptr<data::ArrayAdapter>)) {
this->DispatchedInplacePredict<data::ArrayAdapter> (
x, p_m, model, missing, out_preds, tree_begin, tree_end);
} else if (x.type() == typeid(std::shared_ptr<data::CSRArrayAdapter>)) {
this->DispatchedInplacePredict<data::CSRArrayAdapter> (
x, p_m, model, missing, out_preds, tree_begin, tree_end);
} else { } else {
LOG(FATAL) << "Data type is not supported by CPU Predictor."; return false;
} }
return true;
} }
void PredictInstance(const SparsePage::Inst& inst, void PredictInstance(const SparsePage::Inst& inst,

View File

@ -1,5 +1,5 @@
/*! /*!
* Copyright 2017-2020 by Contributors * Copyright 2017-2021 by Contributors
*/ */
#include <thrust/copy.h> #include <thrust/copy.h>
#include <thrust/device_ptr.h> #include <thrust/device_ptr.h>
@ -644,7 +644,7 @@ class GPUPredictor : public xgboost::Predictor {
} }
template <typename Adapter, typename Loader> template <typename Adapter, typename Loader>
void DispatchedInplacePredict(dmlc::any const &x, void DispatchedInplacePredict(dmlc::any const &x, std::shared_ptr<DMatrix> p_m,
const gbm::GBTreeModel &model, float, const gbm::GBTreeModel &model, float,
PredictionCacheEntry *out_preds, PredictionCacheEntry *out_preds,
uint32_t tree_begin, uint32_t tree_end) const { uint32_t tree_begin, uint32_t tree_end) const {
@ -659,16 +659,20 @@ class GPUPredictor : public xgboost::Predictor {
CHECK_EQ(this->generic_param_->gpu_id, m->DeviceIdx()) CHECK_EQ(this->generic_param_->gpu_id, m->DeviceIdx())
<< "XGBoost is running on device: " << this->generic_param_->gpu_id << ", " << "XGBoost is running on device: " << this->generic_param_->gpu_id << ", "
<< "but data is on: " << m->DeviceIdx(); << "but data is on: " << m->DeviceIdx();
if (p_m) {
p_m->Info().num_row_ = m->NumRows();
this->InitOutPredictions(p_m->Info(), &(out_preds->predictions), model);
} else {
MetaInfo info; MetaInfo info;
info.num_col_ = m->NumColumns();
info.num_row_ = m->NumRows(); info.num_row_ = m->NumRows();
this->InitOutPredictions(info, &(out_preds->predictions), model); this->InitOutPredictions(info, &(out_preds->predictions), model);
}
const uint32_t BLOCK_THREADS = 128; const uint32_t BLOCK_THREADS = 128;
auto GRID_SIZE = static_cast<uint32_t>(common::DivRoundUp(info.num_row_, BLOCK_THREADS)); auto GRID_SIZE = static_cast<uint32_t>(common::DivRoundUp(m->NumRows(), BLOCK_THREADS));
size_t shared_memory_bytes = size_t shared_memory_bytes =
SharedMemoryBytes<BLOCK_THREADS>(info.num_col_, max_shared_memory_bytes); SharedMemoryBytes<BLOCK_THREADS>(m->NumColumns(), max_shared_memory_bytes);
bool use_shared = shared_memory_bytes != 0; bool use_shared = shared_memory_bytes != 0;
size_t entry_start = 0; size_t entry_start = 0;
@ -680,23 +684,25 @@ class GPUPredictor : public xgboost::Predictor {
d_model.categories_tree_segments.ConstDeviceSpan(), d_model.categories_tree_segments.ConstDeviceSpan(),
d_model.categories_node_segments.ConstDeviceSpan(), d_model.categories_node_segments.ConstDeviceSpan(),
d_model.categories.ConstDeviceSpan(), tree_begin, tree_end, m->NumColumns(), d_model.categories.ConstDeviceSpan(), tree_begin, tree_end, m->NumColumns(),
info.num_row_, entry_start, use_shared, output_groups); m->NumRows(), entry_start, use_shared, output_groups);
} }
void InplacePredict(dmlc::any const &x, const gbm::GBTreeModel &model, bool InplacePredict(dmlc::any const &x, std::shared_ptr<DMatrix> p_m,
float missing, PredictionCacheEntry *out_preds, const gbm::GBTreeModel &model, float missing,
uint32_t tree_begin, unsigned tree_end) const override { PredictionCacheEntry *out_preds, uint32_t tree_begin,
unsigned tree_end) const override {
if (x.type() == typeid(std::shared_ptr<data::CupyAdapter>)) { if (x.type() == typeid(std::shared_ptr<data::CupyAdapter>)) {
this->DispatchedInplacePredict< this->DispatchedInplacePredict<
data::CupyAdapter, DeviceAdapterLoader<data::CupyAdapterBatch>>( data::CupyAdapter, DeviceAdapterLoader<data::CupyAdapterBatch>>(
x, model, missing, out_preds, tree_begin, tree_end); x, p_m, model, missing, out_preds, tree_begin, tree_end);
} else if (x.type() == typeid(std::shared_ptr<data::CudfAdapter>)) { } else if (x.type() == typeid(std::shared_ptr<data::CudfAdapter>)) {
this->DispatchedInplacePredict< this->DispatchedInplacePredict<
data::CudfAdapter, DeviceAdapterLoader<data::CudfAdapterBatch>>( data::CudfAdapter, DeviceAdapterLoader<data::CudfAdapterBatch>>(
x, model, missing, out_preds, tree_begin, tree_end); x, p_m, model, missing, out_preds, tree_begin, tree_end);
} else { } else {
LOG(FATAL) << "Only CuPy and CuDF are supported by GPU Predictor."; return false;
} }
return true;
} }
void PredictContribution(DMatrix* p_fmat, void PredictContribution(DMatrix* p_fmat,

View File

@ -1,5 +1,5 @@
/*! /*!
* Copyright 2020 by Contributors * Copyright 2020-2021 by Contributors
*/ */
#include <gtest/gtest.h> #include <gtest/gtest.h>
@ -104,21 +104,24 @@ void TestInplacePrediction(dmlc::any x, std::string predictor,
} }
HostDeviceVector<float> *p_out_predictions_0{nullptr}; HostDeviceVector<float> *p_out_predictions_0{nullptr};
learner->InplacePredict(x, "margin", std::numeric_limits<float>::quiet_NaN(), learner->InplacePredict(x, nullptr, PredictionType::kMargin,
std::numeric_limits<float>::quiet_NaN(),
&p_out_predictions_0, 0, 2); &p_out_predictions_0, 0, 2);
CHECK(p_out_predictions_0); CHECK(p_out_predictions_0);
HostDeviceVector<float> predict_0 (p_out_predictions_0->Size()); HostDeviceVector<float> predict_0 (p_out_predictions_0->Size());
predict_0.Copy(*p_out_predictions_0); predict_0.Copy(*p_out_predictions_0);
HostDeviceVector<float> *p_out_predictions_1{nullptr}; HostDeviceVector<float> *p_out_predictions_1{nullptr};
learner->InplacePredict(x, "margin", std::numeric_limits<float>::quiet_NaN(), learner->InplacePredict(x, nullptr, PredictionType::kMargin,
std::numeric_limits<float>::quiet_NaN(),
&p_out_predictions_1, 2, 4); &p_out_predictions_1, 2, 4);
CHECK(p_out_predictions_1); CHECK(p_out_predictions_1);
HostDeviceVector<float> predict_1 (p_out_predictions_1->Size()); HostDeviceVector<float> predict_1 (p_out_predictions_1->Size());
predict_1.Copy(*p_out_predictions_1); predict_1.Copy(*p_out_predictions_1);
HostDeviceVector<float>* p_out_predictions{nullptr}; HostDeviceVector<float>* p_out_predictions{nullptr};
learner->InplacePredict(x, "margin", std::numeric_limits<float>::quiet_NaN(), learner->InplacePredict(x, nullptr, PredictionType::kMargin,
std::numeric_limits<float>::quiet_NaN(),
&p_out_predictions, 0, 4); &p_out_predictions, 0, 4);
auto& h_pred = p_out_predictions->HostVector(); auto& h_pred = p_out_predictions->HostVector();

View File

@ -11,8 +11,7 @@ import testing as tm
class TestDeviceQuantileDMatrix: class TestDeviceQuantileDMatrix:
def test_dmatrix_numpy_init(self): def test_dmatrix_numpy_init(self):
data = np.random.randn(5, 5) data = np.random.randn(5, 5)
with pytest.raises(TypeError, with pytest.raises(TypeError, match='is not supported'):
match='is not supported for DeviceQuantileDMatrix'):
xgb.DeviceQuantileDMatrix(data, np.ones(5, dtype=np.float64)) xgb.DeviceQuantileDMatrix(data, np.ones(5, dtype=np.float64))
@pytest.mark.skipif(**tm.no_cupy()) @pytest.mark.skipif(**tm.no_cupy())

View File

@ -141,6 +141,13 @@ class TestGPUPredict:
assert np.allclose(cpu_train_score, gpu_train_score) assert np.allclose(cpu_train_score, gpu_train_score)
assert np.allclose(cpu_test_score, gpu_test_score) assert np.allclose(cpu_test_score, gpu_test_score)
def run_inplace_base_margin(self, booster, dtrain, X, base_margin):
import cupy as cp
dtrain.set_info(base_margin=base_margin)
from_inplace = booster.inplace_predict(data=X, base_margin=base_margin)
from_dmatrix = booster.predict(dtrain)
cp.testing.assert_allclose(from_inplace, from_dmatrix)
@pytest.mark.skipif(**tm.no_cupy()) @pytest.mark.skipif(**tm.no_cupy())
def test_inplace_predict_cupy(self): def test_inplace_predict_cupy(self):
import cupy as cp import cupy as cp
@ -175,6 +182,9 @@ class TestGPUPredict:
for i in range(10): for i in range(10):
run_threaded_predict(X, rows, predict_dense) run_threaded_predict(X, rows, predict_dense)
base_margin = cp_rng.randn(rows)
self.run_inplace_base_margin(booster, dtrain, X, base_margin)
@pytest.mark.skipif(**tm.no_cudf()) @pytest.mark.skipif(**tm.no_cudf())
def test_inplace_predict_cudf(self): def test_inplace_predict_cudf(self):
import cupy as cp import cupy as cp
@ -208,6 +218,9 @@ class TestGPUPredict:
for i in range(10): for i in range(10):
run_threaded_predict(X, rows, predict_df) run_threaded_predict(X, rows, predict_df)
base_margin = cudf.Series(rng.randn(rows))
self.run_inplace_base_margin(booster, dtrain, X, base_margin)
@given(strategies.integers(1, 10), @given(strategies.integers(1, 10),
tm.dataset_strategy, shap_parameter_strategy) tm.dataset_strategy, shap_parameter_strategy)
@settings(deadline=None) @settings(deadline=None)

View File

@ -80,20 +80,28 @@ def test_predict_leaf():
class TestInplacePredict: class TestInplacePredict:
'''Tests for running inplace prediction''' '''Tests for running inplace prediction'''
def test_predict(self): @classmethod
rows = 1000 def setup_class(cls):
cols = 10 cls.rows = 100
cls.cols = 10
np.random.seed(1994) cls.rng = np.random.RandomState(1994)
X = np.random.randn(rows, cols) cls.X = cls.rng.randn(cls.rows, cls.cols)
y = np.random.randn(rows) cls.y = cls.rng.randn(cls.rows)
dtrain = xgb.DMatrix(X, y)
booster = xgb.train({'tree_method': 'hist'}, dtrain = xgb.DMatrix(cls.X, cls.y)
cls.booster = xgb.train({'tree_method': 'hist'},
dtrain, num_boost_round=10) dtrain, num_boost_round=10)
test = xgb.DMatrix(X[:10, ...]) cls.test = xgb.DMatrix(cls.X[:10, ...])
def test_predict(self):
booster = self.booster
X = self.X
test = self.test
predt_from_array = booster.inplace_predict(X[:10, ...]) predt_from_array = booster.inplace_predict(X[:10, ...])
predt_from_dmatrix = booster.predict(test) predt_from_dmatrix = booster.predict(test)
@ -111,7 +119,7 @@ class TestInplacePredict:
return np.all(copied_predt == inplace_predt) return np.all(copied_predt == inplace_predt)
for i in range(10): for i in range(10):
run_threaded_predict(X, rows, predict_dense) run_threaded_predict(X, self.rows, predict_dense)
def predict_csr(x): def predict_csr(x):
inplace_predt = booster.inplace_predict(sparse.csr_matrix(x)) inplace_predt = booster.inplace_predict(sparse.csr_matrix(x))
@ -120,4 +128,14 @@ class TestInplacePredict:
return np.all(copied_predt == inplace_predt) return np.all(copied_predt == inplace_predt)
for i in range(10): for i in range(10):
run_threaded_predict(X, rows, predict_csr) run_threaded_predict(X, self.rows, predict_csr)
def test_base_margin(self):
booster = self.booster
base_margin = self.rng.randn(self.rows)
from_inplace = booster.inplace_predict(data=self.X, base_margin=base_margin)
dtrain = xgb.DMatrix(self.X, self.y, base_margin=base_margin)
from_dmatrix = booster.predict(dtrain)
np.testing.assert_allclose(from_dmatrix, from_inplace)

View File

@ -1,6 +1,5 @@
import numpy as np import numpy as np
import xgboost as xgb import xgboost as xgb
import testing as tm
import pytest import pytest
try: try: