Enhance inplace prediction. (#6653)
* Accept array interface for csr and array.
* Accept an optional proxy dmatrix for metainfo. This constructs an explicit `_ProxyDMatrix` type in Python.
* Remove unused doc.
* Add strict output.
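As context for the diff below, a minimal usage sketch of the enhanced Python entry point (not part of the commit; the model and data here are made up, but the new keyword arguments `base_margin` and `strict_shape` come from the changes in this commit):

    import numpy as np
    import xgboost as xgb

    X = np.random.rand(100, 10).astype(np.float32)
    y = np.random.randint(0, 2, size=100)
    booster = xgb.train({"objective": "binary:logistic"}, xgb.DMatrix(X, label=y))

    # base_margin is carried to the C API through an internal proxy DMatrix;
    # strict_shape keeps the group dimension even for single-group models.
    predt = booster.inplace_predict(
        X,
        base_margin=np.full(X.shape[0], 0.5, dtype=np.float32),
        strict_shape=True,
    )
    print(predt.shape)  # (100, 1)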
parent 87ab1ad607
commit 411592a347
@@ -1,5 +1,5 @@
/*!
 * Copyright 2014-2020 by Contributors
 * Copyright 2014-2021 by Contributors
 * \file gbm.h
 * \brief Interface of gradient booster,
 *  that learns through gradient statistics.
@@ -118,7 +118,7 @@ class GradientBooster : public Model, public Configurable {
   * \param layer_begin (Optional) Begining of boosted tree layer used for prediction.
   * \param layer_end (Optional) End of booster layer. 0 means do not limit trees.
   */
  virtual void InplacePredict(dmlc::any const &, float,
  virtual void InplacePredict(dmlc::any const &, std::shared_ptr<DMatrix>, float,
                              PredictionCacheEntry*,
                              uint32_t,
                              uint32_t) const {

@@ -308,6 +308,7 @@ struct StringView {
 public:
  StringView() = default;
  StringView(CharT const* str, size_t size) : str_{str}, size_{size} {}
  explicit StringView(std::string const& str): str_{str.c_str()}, size_{str.size()} {}
  explicit StringView(CharT const* str) : str_{str}, size_{Traits::length(str)} {}

  CharT const& operator[](size_t p) const { return str_[p]; }
@@ -1,5 +1,5 @@
/*!
 * Copyright 2015-2020 by Contributors
 * Copyright 2015-2021 by Contributors
 * \file learner.h
 * \brief Learner interface that integrates objective, gbm and evaluation together.
 *  This is the user facing XGBoost training module.
@@ -30,6 +30,15 @@ class ObjFunction;
class DMatrix;
class Json;

enum class PredictionType : std::uint8_t {  // NOLINT
  kValue = 0,
  kMargin = 1,
  kContribution = 2,
  kApproxContribution = 3,
  kInteraction = 4,
  kLeaf = 5
};

/*! \brief entry to to easily hold returning information */
struct XGBAPIThreadLocalEntry {
  /*! \brief result holder for returning string */
@@ -42,7 +51,10 @@ struct XGBAPIThreadLocalEntry {
  std::vector<bst_float> ret_vec_float;
  /*! \brief temp variable of gradient pairs. */
  std::vector<GradientPair> tmp_gpair;
  /*! \brief Temp variable for returing prediction result. */
  PredictionCacheEntry prediction_entry;
  /*! \brief Temp variable for returing prediction shape. */
  std::vector<bst_ulong> prediction_shape;
};

/*!
@@ -123,13 +135,17 @@ class Learner : public Model, public Configurable, public dmlc::Serializable {
   * \brief Inplace prediction.
   *
   * \param x A type erased data adapter.
   * \param p_m An optional Proxy DMatrix object storing meta info like
   *            base margin.  Can be nullptr.
   * \param type Prediction type.
   * \param missing Missing value in the data.
   * \param [in,out] out_preds Pointer to output prediction vector.
   * \param layer_begin (Optional) Begining of boosted tree layer used for prediction.
   * \param layer_end (Optional) End of booster layer. 0 means do not limit trees.
   * \param layer_begin Begining of boosted tree layer used for prediction.
   * \param layer_end End of booster layer. 0 means do not limit trees.
   */
  virtual void InplacePredict(dmlc::any const& x, std::string const& type,
  virtual void InplacePredict(dmlc::any const &x,
                              std::shared_ptr<DMatrix> p_m,
                              PredictionType type,
                              float missing,
                              HostDeviceVector<bst_float> **out_preds,
                              uint32_t layer_begin, uint32_t layer_end) = 0;
@@ -138,6 +154,7 @@ class Learner : public Model, public Configurable, public dmlc::Serializable {
   * \brief Get number of boosted rounds from gradient booster.
   */
  virtual int32_t BoostedRounds() const = 0;
  virtual uint32_t Groups() const = 0;

  void LoadModel(Json const& in) override = 0;
  void SaveModel(Json* out) const override = 0;
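The new `PredictionType` enum gives the C API a stable numeric code for each prediction task, replacing the old string-typed `type` parameter. A hedged sketch of the mapping as the Python layer below uses it (only the value/margin codes are wired up for `inplace_predict` in this commit; the string names are illustrative, not an established API):

    # Assumed string-to-code mapping, mirroring the enum in learner.h above.
    PREDICTION_TYPE_CODE = {
        "value": 0,                # kValue
        "margin": 1,               # kMargin
        "contribution": 2,         # kContribution
        "approx_contribution": 3,  # kApproxContribution
        "interaction": 4,          # kInteraction
        "leaf": 5,                 # kLeaf
    }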
@@ -1,5 +1,5 @@
/*!
 * Copyright 2017-2020 by Contributors
 * Copyright 2017-2021 by Contributors
 * \file predictor.h
 * \brief Interface of predictor,
 *  performs predictions for a gradient booster.
@@ -142,10 +142,14 @@ class Predictor {
   * \param [in,out] out_preds The output preds.
   * \param tree_begin (Optional) Begining of boosted trees used for prediction.
   * \param tree_end (Optional) End of booster trees. 0 means do not limit trees.
   *
   * \return True if the data can be handled by current predictor, false otherwise.
   */
  virtual void InplacePredict(dmlc::any const &x, const gbm::GBTreeModel &model,
                              float missing, PredictionCacheEntry *out_preds,
                              uint32_t tree_begin = 0, uint32_t tree_end = 0) const = 0;
  virtual bool InplacePredict(dmlc::any const &x, std::shared_ptr<DMatrix> p_m,
                              const gbm::GBTreeModel &model, float missing,
                              PredictionCacheEntry *out_preds,
                              uint32_t tree_begin = 0,
                              uint32_t tree_end = 0) const = 0;
  /**
   * \brief online prediction function, predict score for one instance at a time
   * NOTE: use the batch prediction interface if possible, batch prediction is
@@ -58,21 +58,23 @@ CallbackEnv = collections.namedtuple(
    "evaluation_result_list"])


def from_pystr_to_cstr(data):
    """Convert a list of Python str to C pointer
def from_pystr_to_cstr(data: Union[str, List[str]]):
    """Convert a Python str or list of Python str to C pointer

    Parameters
    ----------
    data : list
        list of str
    data
        str or list of str
    """

    if not isinstance(data, list):
        raise NotImplementedError
    if isinstance(data, str):
        return bytes(data, "utf-8")
    if isinstance(data, list):
        pointers = (ctypes.c_char_p * len(data))()
        data = [bytes(d, 'utf-8') for d in data]
        pointers[:] = data
        return pointers
    raise TypeError()


def from_cstr_to_pystr(data, length):
@@ -190,21 +192,40 @@ def _check_call(ret):
        raise XGBoostError(py_str(_LIB.XGBGetLastError()))


def ctypes2numpy(cptr, length, dtype) -> np.ndarray:
    """Convert a ctypes pointer array to a numpy array."""
    NUMPY_TO_CTYPES_MAPPING = {
def _numpy2ctypes_type(dtype):
    _NUMPY_TO_CTYPES_MAPPING = {
        np.float32: ctypes.c_float,
        np.float64: ctypes.c_double,
        np.uint32: ctypes.c_uint,
        np.uint64: ctypes.c_uint64,
        np.int32: ctypes.c_int32,
        np.int64: ctypes.c_int64,
    }
    if dtype not in NUMPY_TO_CTYPES_MAPPING:
        raise RuntimeError('Supported types: {}'.format(
            NUMPY_TO_CTYPES_MAPPING.keys()))
    ctype = NUMPY_TO_CTYPES_MAPPING[dtype]
    if np.intc is not np.int32:  # Windows
        _NUMPY_TO_CTYPES_MAPPING[np.intc] = _NUMPY_TO_CTYPES_MAPPING[np.int32]
    if dtype not in _NUMPY_TO_CTYPES_MAPPING.keys():
        raise TypeError(
            f"Supported types: {_NUMPY_TO_CTYPES_MAPPING.keys()}, got: {dtype}"
        )
    return _NUMPY_TO_CTYPES_MAPPING[dtype]


def _array_interface(data: np.ndarray) -> bytes:
    interface = data.__array_interface__
    if "mask" in interface:
        interface["mask"] = interface["mask"].__array_interface__
    interface_str = bytes(json.dumps(interface, indent=2), "utf-8")
    return interface_str


def ctypes2numpy(cptr, length, dtype):
    """Convert a ctypes pointer array to a numpy array."""
    ctype = _numpy2ctypes_type(dtype)
    if not isinstance(cptr, ctypes.POINTER(ctype)):
        raise RuntimeError('expected {} pointer'.format(ctype))
        raise RuntimeError("expected {} pointer".format(ctype))
    res = np.zeros(length, dtype=dtype)
    if not ctypes.memmove(res.ctypes.data, cptr, length * res.strides[0]):
        raise RuntimeError('memmove failed')
        raise RuntimeError("memmove failed")
    return res
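`_array_interface` serializes NumPy's standard `__array_interface__` protocol; the new `ArrayAdapter`/`CSRArrayAdapter` on the C++ side parse this JSON instead of taking raw pointers. A quick way to inspect what gets sent (the pointer value in `data` is machine specific):

    import json
    import numpy as np

    x = np.arange(6, dtype=np.float32).reshape(2, 3)
    print(json.dumps(x.__array_interface__, indent=2))
    # {
    #   "data": [140418865906176, false],
    #   "strides": null,
    #   "descr": [["", "<f4"]],
    #   "typestr": "<f4",
    #   "shape": [2, 3],
    #   "version": 3
    # }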
@@ -214,25 +235,21 @@ def ctypes2cupy(cptr, length, dtype):
    import cupy
    from cupy.cuda.memory import MemoryPointer
    from cupy.cuda.memory import UnownedMemory
    CUPY_TO_CTYPES_MAPPING = {
        cupy.float32: ctypes.c_float,
        cupy.uint32: ctypes.c_uint
    }

    CUPY_TO_CTYPES_MAPPING = {cupy.float32: ctypes.c_float, cupy.uint32: ctypes.c_uint}
    if dtype not in CUPY_TO_CTYPES_MAPPING.keys():
        raise RuntimeError('Supported types: {}'.format(
            CUPY_TO_CTYPES_MAPPING.keys()
        ))
        raise RuntimeError("Supported types: {}".format(CUPY_TO_CTYPES_MAPPING.keys()))
    addr = ctypes.cast(cptr, ctypes.c_void_p).value
    # pylint: disable=c-extension-no-member,no-member
    device = cupy.cuda.runtime.pointerGetAttributes(addr).device
    # The owner field is just used to keep the memory alive with ref count.  As
    # unowned's life time is scoped within this function we don't need that.
    unownd = UnownedMemory(
        addr, length.value * ctypes.sizeof(CUPY_TO_CTYPES_MAPPING[dtype]),
        owner=None)
        addr, length * ctypes.sizeof(CUPY_TO_CTYPES_MAPPING[dtype]), owner=None
    )
    memptr = MemoryPointer(unownd, 0)
    # pylint: disable=unexpected-keyword-arg
    mem = cupy.ndarray((length.value, ), dtype=dtype, memptr=memptr)
    mem = cupy.ndarray((length,), dtype=dtype, memptr=memptr)
    assert mem.device.id == device
    arr = cupy.array(mem, copy=True)
    return arr
@@ -256,28 +273,29 @@ def c_str(string):

def c_array(ctype, values):
    """Convert a python string to c array."""
    if (isinstance(values, np.ndarray)
            and values.dtype.itemsize == ctypes.sizeof(ctype)):
    if isinstance(values, np.ndarray) and values.dtype.itemsize == ctypes.sizeof(ctype):
        return (ctype * len(values)).from_buffer_copy(values)
    return (ctype * len(values))(*values)


def _prediction_output(shape, dims, predts, is_cuda):
    arr_shape: np.ndarray = ctypes2numpy(shape, dims.value, np.uint64)
    length = int(np.prod(arr_shape))
    if is_cuda:
        arr_predict = ctypes2cupy(predts, length, np.float32)
    else:
        arr_predict: np.ndarray = ctypes2numpy(predts, length, np.float32)
    arr_predict = arr_predict.reshape(arr_shape)
    return arr_predict
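`_prediction_output` replaces the old ad-hoc `reshape_output` helper: the C API now reports the exact output shape, and Python simply reshapes the flat buffer. Illustrated with made-up values:

    import numpy as np

    flat = np.arange(6, dtype=np.float32)      # prediction buffer from the C API
    shape = np.array([2, 3], dtype=np.uint64)  # *out_shape with *out_dim == 2
    print(flat.reshape(shape))                 # 2 samples x 3 groups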

class DataIter:
    '''The interface for user defined data iterator. Currently is only
    supported by Device DMatrix.
    '''The interface for user defined data iterator. Currently is only supported by Device
    DMatrix.

    Parameters
    ----------

    rows : int
        Total number of rows combining all batches.
    cols : int
        Number of columns for each batch.
    '''
    def __init__(self):
        proxy_handle = ctypes.c_void_p()
        _check_call(_LIB.XGProxyDMatrixCreate(ctypes.byref(proxy_handle)))
        self._handle = DeviceQuantileDMatrix(proxy_handle)
        self._handle = _ProxyDMatrix()
        self.exception = None

    @property
@@ -300,12 +318,7 @@ class DataIter:
        if self.exception is not None:
            return 0

        def data_handle(data, label=None, weight=None, base_margin=None,
                        group=None,
                        qid=None,
                        label_lower_bound=None, label_upper_bound=None,
                        feature_names=None, feature_types=None,
                        feature_weights=None):
        def data_handle(data, feature_names=None, feature_types=None, **kwargs):
            from .data import dispatch_device_quantile_dmatrix_set_data
            from .data import _device_quantile_transform
            data, feature_names, feature_types = _device_quantile_transform(
@@ -313,16 +326,9 @@ class DataIter:
            )
            dispatch_device_quantile_dmatrix_set_data(self.proxy, data)
            self.proxy.set_info(
                label=label,
                weight=weight,
                base_margin=base_margin,
                group=group,
                qid=qid,
                label_lower_bound=label_lower_bound,
                label_upper_bound=label_upper_bound,
                feature_names=feature_names,
                feature_types=feature_types,
                feature_weights=feature_weights
                **kwargs,
            )
        try:
            # Differ the exception in order to return 0 and stop the iteration.
@@ -558,7 +564,7 @@ class DMatrix:  # pylint: disable=too-many-instance-attributes
                 feature_types=None,
                 feature_weights=None
                 ) -> None:
        """Set meta info for DMatrix. See doc string for DMatrix constructor."""
        """Set meta info for DMatrix. See doc string for :py:obj:`xgboost.DMatrix`."""
        from .data import dispatch_meta_backend

        if label is not None:
@@ -959,18 +965,52 @@ class DMatrix:  # pylint: disable=too-many-instance-attributes
                    c_bst_ulong(0)))


class _ProxyDMatrix(DMatrix):
    """A placeholder class when DMatrix cannot be constructed (DeviceQuantileDMatrix,
    inplace_predict).

    """
    def __init__(self):  # pylint: disable=super-init-not-called
        self.handle = ctypes.c_void_p()
        _check_call(_LIB.XGProxyDMatrixCreate(ctypes.byref(self.handle)))

    def _set_data_from_cuda_interface(self, data):
        '''Set data from CUDA array interface.'''
        interface = data.__cuda_array_interface__
        interface_str = bytes(json.dumps(interface, indent=2), 'utf-8')
        _check_call(
            _LIB.XGDeviceQuantileDMatrixSetDataCudaArrayInterface(
                self.handle,
                interface_str
            )
        )

    def _set_data_from_cuda_columnar(self, data):
        '''Set data from CUDA columnar format.1'''
        from .data import _cudf_array_interfaces
        interfaces_str = _cudf_array_interfaces(data)
        _check_call(
            _LIB.XGDeviceQuantileDMatrixSetDataCudaColumnar(
                self.handle,
                interfaces_str
            )
        )


class DeviceQuantileDMatrix(DMatrix):
    """Device memory Data Matrix used in XGBoost for training with
    tree_method='gpu_hist'. Do not use this for test/validation tasks as some
    information may be lost in quantisation. This DMatrix is primarily designed
    to save memory in training from device memory inputs by avoiding
    intermediate storage. Set max_bin to control the number of bins during
    quantisation. See doc string in `DMatrix` for documents on meta info.
    """Device memory Data Matrix used in XGBoost for training with tree_method='gpu_hist'. Do
    not use this for test/validation tasks as some information may be lost in
    quantisation. This DMatrix is primarily designed to save memory in training from
    device memory inputs by avoiding intermediate storage. Set max_bin to control the
    number of bins during quantisation. See doc string in :py:obj:`xgboost.DMatrix` for
    documents on meta info.

    You can construct DeviceQuantileDMatrix from cupy/cudf/dlpack.

    .. versionadded:: 1.1.0

    """

    @_deprecate_positional_args
    def __init__(  # pylint: disable=super-init-not-called
        self,
@@ -1000,58 +1040,72 @@ class DeviceQuantileDMatrix(DMatrix):
        if isinstance(data, ctypes.c_void_p):
            self.handle = data
            return
        from .data import init_device_quantile_dmatrix
        handle, feature_names, feature_types = init_device_quantile_dmatrix(
            data,
            label=label, weight=weight,
            base_margin=base_margin,
            group=group,
            qid=qid,
            missing=self.missing,
            label_lower_bound=label_lower_bound,
            label_upper_bound=label_upper_bound,
            feature_weights=feature_weights,
            feature_names=feature_names,
            feature_types=feature_types,
            threads=self.nthread,
            max_bin=self.max_bin,
        )

        if enable_categorical:
            raise NotImplementedError(
                'categorical support is not enabled on DeviceQuantileDMatrix.'
            )
        self.handle = handle
        if qid is not None and group is not None:
            raise ValueError(
                'Only one of the eval_qid or eval_group for each evaluation '
                'dataset should be provided.'
            )

        self.feature_names = feature_names
        self.feature_types = feature_types

    def _set_data_from_cuda_interface(self, data):
        '''Set data from CUDA array interface.'''
        interface = data.__cuda_array_interface__
        interface_str = bytes(json.dumps(interface, indent=2), 'utf-8')
        _check_call(
            _LIB.XGDeviceQuantileDMatrixSetDataCudaArrayInterface(
                self.handle,
                interface_str
            )
        self._init(
            data,
            label=label,
            weight=weight,
            base_margin=base_margin,
            group=group,
            qid=qid,
            label_lower_bound=label_lower_bound,
            label_upper_bound=label_upper_bound,
            feature_weights=feature_weights,
            feature_names=feature_names,
            feature_types=feature_types,
        )

    def _set_data_from_cuda_columnar(self, data):
        '''Set data from CUDA columnar format.1'''
        from .data import _cudf_array_interfaces
        interfaces_str = _cudf_array_interfaces(data)
        _check_call(
            _LIB.XGDeviceQuantileDMatrixSetDataCudaColumnar(
                self.handle,
                interfaces_str
    def _init(self, data, feature_names, feature_types, **meta):
        from .data import (
            _is_dlpack,
            _transform_dlpack,
            _is_iter,
            SingleBatchInternalIter,
        )

        if _is_dlpack(data):
            # We specialize for dlpack because cupy will take the memory from it so
            # it can't be transformed twice.
            data = _transform_dlpack(data)
        if _is_iter(data):
            it = data
        else:
            it = SingleBatchInternalIter(
                data, **meta, feature_names=feature_names, feature_types=feature_types
            )

        reset_callback = ctypes.CFUNCTYPE(None, ctypes.c_void_p)(it.reset_wrapper)
        next_callback = ctypes.CFUNCTYPE(
            ctypes.c_int,
            ctypes.c_void_p,
        )(it.next_wrapper)
        handle = ctypes.c_void_p()
        ret = _LIB.XGDeviceQuantileDMatrixCreateFromCallback(
            None,
            it.proxy.handle,
            reset_callback,
            next_callback,
            ctypes.c_float(self.missing),
            ctypes.c_int(self.nthread),
            ctypes.c_int(self.max_bin),
            ctypes.byref(handle),
        )
        if it.exception:
            raise it.exception
        # delay check_call to throw intermediate exception first
        _check_call(ret)
        self.handle = handle


Objective = Callable[[np.ndarray, DMatrix], Tuple[np.ndarray, np.ndarray]]
Metric = Callable[[np.ndarray, DMatrix], Tuple[str, float]]
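The rewritten `_init` moves the callback plumbing from `data.py` into the class but keeps the public behaviour. A minimal usage sketch (assumes a GPU build of XGBoost with cupy installed):

    import cupy as cp
    import xgboost as xgb

    X = cp.random.rand(1000, 16, dtype=cp.float32)
    y = cp.random.rand(1000, dtype=cp.float32)

    # Batches are streamed through the internal _ProxyDMatrix and quantised
    # directly, avoiding an intermediate copy of the full matrix.
    Xy = xgb.DeviceQuantileDMatrix(X, label=y, max_bin=256)
    booster = xgb.train({"tree_method": "gpu_hist"}, Xy, num_boost_round=10)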
@@ -1346,7 +1400,7 @@ class Booster(object):

    def boost(self, dtrain, grad, hess):
        """Boost the booster for one iteration, with customized gradient
        statistics.  Like :func:`xgboost.core.Booster.update`, this
        statistics.  Like :py:func:`xgboost.Booster.update`, this
        function should not be called directly by users.

        Parameters
@@ -1360,7 +1414,9 @@ class Booster(object):

        """
        if len(grad) != len(hess):
            raise ValueError('grad / hess length mismatch: {} / {}'.format(len(grad), len(hess)))
            raise ValueError(
                'grad / hess length mismatch: {} / {}'.format(len(grad), len(hess))
            )
        if not isinstance(dtrain, DMatrix):
            raise TypeError('invalid training matrix: {}'.format(type(dtrain).__name__))
        self._validate_features(dtrain)
@@ -1453,17 +1509,12 @@ class Booster(object):
                training=False):
        """Predict with data.

        .. note:: This function is not thread safe except for ``gbtree``
          booster.
        .. note:: This function is not thread safe except for ``gbtree`` booster.

          For ``gbtree`` booster, the thread safety is guaranteed by locks.
          For lock free prediction use ``inplace_predict`` instead.  Also, the
          safety does not hold when used in conjunction with other methods.

          When using booster other than ``gbtree``, predict can only be called
          from one thread.  If you want to run prediction using multiple
          thread, call ``bst.copy()`` to make copies of model object and then
          call ``predict()``.
          When using booster other than ``gbtree``, predict can only be called from one
          thread.  If you want to run prediction using multiple thread, call
          :py:meth:`xgboost.Booster.copy` to make copies of model object and then call
          ``predict()``.

        Parameters
        ----------
@@ -1579,9 +1630,17 @@ class Booster(object):
            preds = preds.reshape(nrow, chunk_size)
        return preds

    def inplace_predict(self, data, iteration_range=(0, 0),
                        predict_type='value', missing=np.nan):
        '''Run prediction in-place, Unlike ``predict`` method, inplace prediction does
    def inplace_predict(
        self,
        data,
        iteration_range: Tuple[int, int] = (0, 0),
        predict_type: str = "value",
        missing: float = np.nan,
        validate_features: bool = True,
        base_margin: Any = None,
        strict_shape: bool = False
    ):
        """Run prediction in-place, Unlike ``predict`` method, inplace prediction does
        not cache the prediction result.

        Calling only ``inplace_predict`` in multiple threads is safe and lock
@@ -1617,6 +1676,15 @@ class Booster(object):
        missing : float
            Value in the input data which needs to be present as a missing
            value.
        validate_features:
            See :py:meth:`xgboost.Booster.predict` for details.
        base_margin:
            See :py:obj:`xgboost.DMatrix` for details.
        strict_shape:
            When set to True, output shape is invariant to whether classification is used.
            For both value and margin prediction, the output shape is (n_samples,
            n_groups), n_groups == 1 when multi-class is not used.  Default to False, in
            which case the output shape can be (n_samples, ) if multi-class is not used.

        Returns
        -------
@@ -1624,107 +1692,117 @@ class Booster(object):
            The prediction result.  When input data is on GPU, prediction
            result is stored in a cupy array.

        '''

        def reshape_output(predt, rows):
            '''Reshape for multi-output prediction.'''
            if predt.size != rows and predt.size % rows == 0:
                cols = int(predt.size / rows)
                predt = predt.reshape(rows, cols)
                return predt
            return predt

        length = c_bst_ulong()
        """
        preds = ctypes.POINTER(ctypes.c_float)()
        iteration_range = (ctypes.c_uint(iteration_range[0]),
                           ctypes.c_uint(iteration_range[1]))

        # once caching is supported, we can pass id(data) as cache id.
        try:
            import pandas as pd

            if isinstance(data, pd.DataFrame):
                data = data.values
        except ImportError:
            pass
        args = {
            "type": 0,
            "training": False,
            "iteration_begin": iteration_range[0],
            "iteration_end": iteration_range[1],
            "missing": missing,
            "strict_shape": strict_shape,
            "cache_id": 0,
        }
        if predict_type == "margin":
            args["type"] = 1
        shape = ctypes.POINTER(c_bst_ulong)()
        dims = c_bst_ulong()

        if base_margin is not None:
            proxy = _ProxyDMatrix()
            proxy.set_info(base_margin=base_margin)
            p_handle = proxy.handle
        else:
            proxy = None
            p_handle = ctypes.c_void_p()
        assert proxy is None or isinstance(proxy, _ProxyDMatrix)
        if validate_features:
            if len(data.shape) != 1 and self.num_features() != data.shape[1]:
                raise ValueError(
                    f"Feature shape mismatch, expected: {self.num_features()}, "
                    f"got {data.shape[0]}"
                )

        if isinstance(data, np.ndarray):
            assert data.flags.c_contiguous
            arr = np.array(data.reshape(data.size), copy=False,
                           dtype=np.float32)
            _check_call(_LIB.XGBoosterPredictFromDense(
            from .data import _maybe_np_slice
            data = _maybe_np_slice(data, data.dtype)
            _check_call(
                _LIB.XGBoosterPredictFromDense(
                    self.handle,
                    arr.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
                    c_bst_ulong(data.shape[0]),
                    c_bst_ulong(data.shape[1]),
                    ctypes.c_float(missing),
                    iteration_range[0],
                    iteration_range[1],
                    c_str(predict_type),
                    c_bst_ulong(0),
                    ctypes.byref(length),
                    ctypes.byref(preds)
                ))
            preds = ctypes2numpy(preds, length.value, np.float32)
            rows = data.shape[0]
            return reshape_output(preds, rows)
                    _array_interface(data),
                    from_pystr_to_cstr(json.dumps(args)),
                    p_handle,
                    ctypes.byref(shape),
                    ctypes.byref(dims),
                    ctypes.byref(preds),
                )
            )
            return _prediction_output(shape, dims, preds, False)
        if isinstance(data, scipy.sparse.csr_matrix):
            csr = data
            _check_call(_LIB.XGBoosterPredictFromCSR(
            _check_call(
                _LIB.XGBoosterPredictFromCSR(
                    self.handle,
                    c_array(ctypes.c_size_t, csr.indptr),
                    c_array(ctypes.c_uint, csr.indices),
                    c_array(ctypes.c_float, csr.data),
                    ctypes.c_size_t(len(csr.indptr)),
                    ctypes.c_size_t(len(csr.data)),
                    _array_interface(csr.indptr),
                    _array_interface(csr.indices),
                    _array_interface(csr.data),
                    ctypes.c_size_t(csr.shape[1]),
                    ctypes.c_float(missing),
                    iteration_range[0],
                    iteration_range[1],
                    c_str(predict_type),
                    c_bst_ulong(0),
                    ctypes.byref(length),
                    ctypes.byref(preds)))
            preds = ctypes2numpy(preds, length.value, np.float32)
            rows = data.shape[0]
            return reshape_output(preds, rows)
        if lazy_isinstance(data, 'cupy.core.core', 'ndarray'):
            assert data.flags.c_contiguous
                    from_pystr_to_cstr(json.dumps(args)),
                    p_handle,
                    ctypes.byref(shape),
                    ctypes.byref(dims),
                    ctypes.byref(preds),
                )
            )
            return _prediction_output(shape, dims, preds, False)
        if lazy_isinstance(data, "cupy.core.core", "ndarray"):
            from .data import _transform_cupy_array
            data = _transform_cupy_array(data)
            interface = data.__cuda_array_interface__
            if 'mask' in interface:
                interface['mask'] = interface['mask'].__cuda_array_interface__
            interface_str = bytes(json.dumps(interface, indent=2), 'utf-8')
            _check_call(_LIB.XGBoosterPredictFromArrayInterface(
            if "mask" in interface:
                interface["mask"] = interface["mask"].__cuda_array_interface__
            interface_str = bytes(json.dumps(interface, indent=2), "utf-8")
            _check_call(
                _LIB.XGBoosterPredictFromArrayInterface(
                    self.handle,
                    interface_str,
                    ctypes.c_float(missing),
                    iteration_range[0],
                    iteration_range[1],
                    c_str(predict_type),
                    c_bst_ulong(0),
                    ctypes.byref(length),
                    ctypes.byref(preds)))
            mem = ctypes2cupy(preds, length, np.float32)
            rows = data.shape[0]
            return reshape_output(mem, rows)
        if lazy_isinstance(data, 'cudf.core.dataframe', 'DataFrame'):
                    from_pystr_to_cstr(json.dumps(args)),
                    p_handle,
                    ctypes.byref(shape),
                    ctypes.byref(dims),
                    ctypes.byref(preds),
                )
            )
            return _prediction_output(shape, dims, preds, True)
        if lazy_isinstance(data, "cudf.core.dataframe", "DataFrame"):
            from .data import _cudf_array_interfaces

            interfaces_str = _cudf_array_interfaces(data)
            _check_call(_LIB.XGBoosterPredictFromArrayInterfaceColumns(
            _check_call(
                _LIB.XGBoosterPredictFromArrayInterfaceColumns(
                    self.handle,
                    interfaces_str,
                    ctypes.c_float(missing),
                    iteration_range[0],
                    iteration_range[1],
                    c_str(predict_type),
                    c_bst_ulong(0),
                    ctypes.byref(length),
                    ctypes.byref(preds)))
            mem = ctypes2cupy(preds, length, np.float32)
            rows = data.shape[0]
            predt = reshape_output(mem, rows)
            return predt
                    from_pystr_to_cstr(json.dumps(args)),
                    p_handle,
                    ctypes.byref(shape),
                    ctypes.byref(dims),
                    ctypes.byref(preds),
                )
            )
            return _prediction_output(shape, dims, preds, True)

        raise TypeError('Data type:' + str(type(data)) +
                        ' not supported by inplace prediction.')
        raise TypeError(
            "Data type:" + str(type(data)) + " not supported by inplace prediction."
        )
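To make the `strict_shape` behaviour concrete, a sketch comparing a multi-class model with a regressor (illustrative data; the shapes follow the docstring above):

    import numpy as np
    import xgboost as xgb

    X = np.random.rand(50, 4).astype(np.float32)
    y = np.random.randint(0, 3, size=50)

    clf = xgb.train(
        {"objective": "multi:softprob", "num_class": 3}, xgb.DMatrix(X, label=y)
    )
    print(clf.inplace_predict(X).shape)                     # (50, 3)
    print(clf.inplace_predict(X, strict_shape=True).shape)  # (50, 3)

    reg = xgb.train({"objective": "reg:squarederror"}, xgb.DMatrix(X, label=y))
    print(reg.inplace_predict(X).shape)                     # (50,)
    print(reg.inplace_predict(X, strict_shape=True).shape)  # (50, 1)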

    def save_model(self, fname):
        """Save the model to a file.
@@ -187,8 +187,8 @@ class DaskDMatrix:
    `DaskDMatrix` forces all lazy computation to be carried out.  Wait for the input data
    explicitly if you want to see actual computation of constructing `DaskDMatrix`.

    See doc string for DMatrix constructor for other parameters.  DaskDMatrix accepts only
    dask collection.
    See doc for :py:obj:`xgboost.DMatrix` constructor for other parameters.  DaskDMatrix
    accepts only dask collection.

    .. note::

@@ -575,7 +575,8 @@ class DaskDeviceQuantileDMatrix(DaskDMatrix):
    memory usage by eliminating data copies.  Internally the all partitions/chunks of data
    are merged by weighted GK sketching.  So the number of partitions from dask may affect
    training accuracy as GK generates bounded error for each merge.  See doc string for
    `DeviceQuantileDMatrix` and `DMatrix` for other parameters.
    :py:obj:`xgboost.DeviceQuantileDMatrix` and :py:obj:`xgboost.DMatrix` for other
    parameters.

    .. versionadded:: 1.2.0
@@ -5,11 +5,12 @@ import ctypes
import json
import warnings
import os
from typing import Any

import numpy as np

from .core import c_array, _LIB, _check_call, c_str
from .core import DataIter, DeviceQuantileDMatrix, DMatrix
from .core import DataIter, _ProxyDMatrix, DMatrix
from .compat import lazy_isinstance

c_bst_ulong = ctypes.c_uint64  # pylint: disable=invalid-name
@@ -113,7 +114,7 @@ def _maybe_np_slice(data, dtype):
    return data


def _transform_np_array(data: np.ndarray):
def _transform_np_array(data: np.ndarray) -> np.ndarray:
    if not isinstance(data, np.ndarray) and hasattr(data, '__array__'):
        data = np.array(data, copy=False)
    if len(data.shape) != 2:
@@ -142,7 +143,7 @@ def _from_numpy_array(data, missing, nthread, feature_names, feature_types):
    input layout and type if memory use is a concern.

    """
    flatten = _transform_np_array(data)
    flatten: np.ndarray = _transform_np_array(data)
    handle = ctypes.c_void_p()
    _check_call(_LIB.XGDMatrixCreateFromMat_omp(
        flatten.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
@@ -783,54 +784,6 @@ class SingleBatchInternalIter(DataIter):  # pylint: disable=R0902
        self.it = 0


def init_device_quantile_dmatrix(
        data, missing, max_bin, threads, feature_names, feature_types, **meta
):
    '''Constructor for DeviceQuantileDMatrix.'''
    if not any([_is_cudf_df(data), _is_cudf_ser(data), _is_cupy_array(data),
                _is_dlpack(data), _is_iter(data)]):
        raise TypeError(str(type(data)) +
                        ' is not supported for DeviceQuantileDMatrix')
    if _is_dlpack(data):
        # We specialize for dlpack because cupy will take the memory from it so
        # it can't be transformed twice.
        data = _transform_dlpack(data)
    if _is_iter(data):
        it = data
    else:
        it = SingleBatchInternalIter(
            data, **meta, feature_names=feature_names,
            feature_types=feature_types)

    reset_factory = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
    reset_callback = reset_factory(it.reset_wrapper)
    next_factory = ctypes.CFUNCTYPE(
        ctypes.c_int,
        ctypes.c_void_p,
    )
    next_callback = next_factory(it.next_wrapper)
    handle = ctypes.c_void_p()
    ret = _LIB.XGDeviceQuantileDMatrixCreateFromCallback(
        None,
        it.proxy.handle,
        reset_callback,
        next_callback,
        ctypes.c_float(missing),
        ctypes.c_int(threads),
        ctypes.c_int(max_bin),
        ctypes.byref(handle)
    )
    if it.exception:
        raise it.exception
    # delay check_call to throw intermediate exception first
    _check_call(ret)
    matrix = DeviceQuantileDMatrix(handle)
    feature_names = matrix.feature_names
    feature_types = matrix.feature_types
    matrix.handle = None
    return handle, feature_names, feature_types


def _device_quantile_transform(data, feature_names, feature_types):
    if _is_cudf_df(data):
        return _transform_cudf_df(data, feature_names, feature_types)
@@ -845,7 +798,7 @@ def _device_quantile_transform(data, feature_names, feature_types):
                          str(type(data)))


def dispatch_device_quantile_dmatrix_set_data(proxy, data):
def dispatch_device_quantile_dmatrix_set_data(proxy: _ProxyDMatrix, data: Any) -> None:
    '''Dispatch for DeviceQuantileDMatrix.'''
    if _is_cudf_df(data):
        proxy._set_data_from_cuda_columnar(data)  # pylint: disable=W0212
@@ -21,6 +21,7 @@
#include "xgboost/global_config.h"

#include "c_api_error.h"
#include "c_api_utils.h"
#include "../common/io.h"
#include "../common/charconv.h"
#include "../data/adapter.h"
@@ -617,89 +618,91 @@ XGB_DLL int XGBoosterPredict(BoosterHandle handle,
  API_END();
}


template <typename T>
void InplacePredictImpl(std::shared_ptr<T> x, std::shared_ptr<DMatrix> p_m,
                        char const *c_json_config, Learner *learner,
                        size_t n_rows, size_t n_cols,
                        xgboost::bst_ulong const **out_shape,
                        xgboost::bst_ulong *out_dim, const float **out_result) {
  auto config = Json::Load(StringView{c_json_config});
  CHECK_EQ(get<Integer const>(config["cache_id"]), 0) << "Cache ID is not supported yet";

  HostDeviceVector<float>* p_predt { nullptr };
  auto type = PredictionType(get<Integer const>(config["type"]));
  learner->InplacePredict(x, p_m, type, get<Number const>(config["missing"]),
                          &p_predt,
                          get<Integer const>(config["iteration_begin"]),
                          get<Integer const>(config["iteration_end"]));
  CHECK(p_predt);
  auto &shape = learner->GetThreadLocal().prediction_shape;
  auto chunksize = n_rows == 0 ? 0 : p_predt->Size() / n_rows;
  bool strict_shape = get<Boolean const>(config["strict_shape"]);
  CalcPredictShape(strict_shape, type, n_rows, n_cols, chunksize, learner->Groups(),
                   learner->BoostedRounds(), &shape, out_dim);
  *out_result = dmlc::BeginPtr(p_predt->HostVector());
  *out_shape = dmlc::BeginPtr(shape);
}

// A hidden API as cache id is not being supported yet.
XGB_DLL int XGBoosterPredictFromDense(BoosterHandle handle, float *values,
                                      xgboost::bst_ulong n_rows,
                                      xgboost::bst_ulong n_cols,
                                      float missing,
                                      unsigned iteration_begin,
                                      unsigned iteration_end,
                                      char const* c_type,
                                      xgboost::bst_ulong cache_id,
                                      xgboost::bst_ulong *out_len,
XGB_DLL int XGBoosterPredictFromDense(BoosterHandle handle,
                                      char const *array_interface,
                                      char const *c_json_config,
                                      DMatrixHandle m,
                                      xgboost::bst_ulong const **out_shape,
                                      xgboost::bst_ulong *out_dim,
                                      const float **out_result) {
  API_BEGIN();
  CHECK_HANDLE();
  CHECK_EQ(cache_id, 0) << "Cache ID is not supported yet";
  std::shared_ptr<xgboost::data::ArrayAdapter> x{
      new xgboost::data::ArrayAdapter(StringView{array_interface})};
  std::shared_ptr<DMatrix> p_m {nullptr};
  if (m) {
    p_m = *static_cast<std::shared_ptr<DMatrix> *>(m);
  }
  auto *learner = static_cast<xgboost::Learner *>(handle);

  std::shared_ptr<xgboost::data::DenseAdapter> x{
      new xgboost::data::DenseAdapter(values, n_rows, n_cols)};
  HostDeviceVector<float>* p_predt { nullptr };
  std::string type { c_type };
  learner->InplacePredict(x, type, missing, &p_predt, iteration_begin, iteration_end);
  CHECK(p_predt);

  *out_result = dmlc::BeginPtr(p_predt->HostVector());
  *out_len = static_cast<xgboost::bst_ulong>(p_predt->Size());
  InplacePredictImpl(x, p_m, c_json_config, learner, x->NumRows(),
                     x->NumColumns(), out_shape, out_dim, out_result);
  API_END();
}

// A hidden API as cache id is not being supported yet.
XGB_DLL int XGBoosterPredictFromCSR(BoosterHandle handle,
                                    const size_t* indptr,
                                    const unsigned* indices,
                                    const bst_float* data,
                                    size_t nindptr,
                                    size_t nelem,
                                    size_t num_col,
                                    float missing,
                                    unsigned iteration_begin,
                                    unsigned iteration_end,
                                    char const *c_type,
                                    xgboost::bst_ulong cache_id,
                                    xgboost::bst_ulong *out_len,
XGB_DLL int XGBoosterPredictFromCSR(BoosterHandle handle, char const *indptr,
                                    char const *indices, char const *data,
                                    xgboost::bst_ulong cols,
                                    char const *c_json_config, DMatrixHandle m,
                                    xgboost::bst_ulong const **out_shape,
                                    xgboost::bst_ulong *out_dim,
                                    const float **out_result) {
  API_BEGIN();
  CHECK_HANDLE();
  CHECK_EQ(cache_id, 0) << "Cache ID is not supported yet";
  std::shared_ptr<xgboost::data::CSRArrayAdapter> x{
      new xgboost::data::CSRArrayAdapter{
          StringView{indptr}, StringView{indices}, StringView{data}, cols}};
  std::shared_ptr<DMatrix> p_m {nullptr};
  if (m) {
    p_m = *static_cast<std::shared_ptr<DMatrix> *>(m);
  }
  auto *learner = static_cast<xgboost::Learner *>(handle);

  std::shared_ptr<xgboost::data::CSRAdapter> x{
      new xgboost::data::CSRAdapter(indptr, indices, data, nindptr - 1, nelem, num_col)};
  HostDeviceVector<float>* p_predt { nullptr };
  std::string type { c_type };
  learner->InplacePredict(x, type, missing, &p_predt, iteration_begin, iteration_end);
  CHECK(p_predt);

  *out_result = dmlc::BeginPtr(p_predt->HostVector());
  *out_len = static_cast<xgboost::bst_ulong>(p_predt->Size());
  InplacePredictImpl(x, p_m, c_json_config, learner, x->NumRows(),
                     x->NumColumns(), out_shape, out_dim, out_result);
  API_END();
}

#if !defined(XGBOOST_USE_CUDA)
XGB_DLL int XGBoosterPredictFromArrayInterfaceColumns(BoosterHandle handle,
                                                      char const* c_json_strs,
                                                      float missing,
                                                      unsigned iteration_begin,
                                                      unsigned iteration_end,
                                                      char const* c_type,
                                                      xgboost::bst_ulong cache_id,
                                                      xgboost::bst_ulong *out_len,
                                                      float const** out_result) {
XGB_DLL int XGBoosterPredictFromArrayInterface(
    BoosterHandle handle, char const *c_json_strs, char const *c_json_config,
    DMatrixHandle m, xgboost::bst_ulong const **out_shape, xgboost::bst_ulong *out_dim,
    const float **out_result) {
  API_BEGIN();
  CHECK_HANDLE();
  common::AssertGPUSupport();
  API_END();
}
XGB_DLL int XGBoosterPredictFromArrayInterface(BoosterHandle handle,
                                               char const* c_json_strs,
                                               float missing,
                                               unsigned iteration_begin,
                                               unsigned iteration_end,
                                               char const* c_type,
                                               xgboost::bst_ulong cache_id,
                                               xgboost::bst_ulong *out_len,

XGB_DLL int XGBoosterPredictFromArrayInterfaceColumns(
    BoosterHandle handle, char const *c_json_strs, char const *c_json_config,
    DMatrixHandle m, xgboost::bst_ulong const **out_shape,
    xgboost::bst_ulong *out_dim, const float **out_result) {
  API_BEGIN();
  CHECK_HANDLE();
@@ -1,8 +1,9 @@
// Copyright (c) 2019-2020 by Contributors
// Copyright (c) 2019-2021 by Contributors
#include "xgboost/data.h"
#include "xgboost/c_api.h"
#include "xgboost/learner.h"
#include "c_api_error.h"
#include "c_api_utils.h"
#include "../data/device_adapter.cuh"

using namespace xgboost;  // NOLINT
@@ -30,59 +31,63 @@ XGB_DLL int XGDMatrixCreateFromArrayInterface(char const* c_json_strs,
  API_END();
}

// A hidden API as cache id is not being supported yet.
XGB_DLL int XGBoosterPredictFromArrayInterfaceColumns(BoosterHandle handle,
                                                      char const* c_json_strs,
                                                      float missing,
                                                      unsigned iteration_begin,
                                                      unsigned iteration_end,
                                                      char const* c_type,
                                                      xgboost::bst_ulong cache_id,
                                                      xgboost::bst_ulong *out_len,
                                                      float const** out_result) {
template <typename T>
int InplacePreidctCuda(BoosterHandle handle, char const *c_json_strs,
                       char const *c_json_config,
                       std::shared_ptr<DMatrix> p_m,
                       xgboost::bst_ulong const **out_shape,
                       xgboost::bst_ulong *out_dim, const float **out_result) {
  API_BEGIN();
  CHECK_HANDLE();
  CHECK_EQ(cache_id, 0) << "Cache ID is not supported yet";
  auto config = Json::Load(StringView{c_json_config});
  CHECK_EQ(get<Integer const>(config["cache_id"]), 0)
      << "Cache ID is not supported yet";
  auto *learner = static_cast<Learner *>(handle);

  std::string json_str{c_json_strs};
  auto x = std::make_shared<data::CudfAdapter>(json_str);
  auto x = std::make_shared<T>(json_str);
  HostDeviceVector<float> *p_predt{nullptr};
  std::string type { c_type };
  learner->InplacePredict(x, type, missing, &p_predt, iteration_begin, iteration_end);
  auto type = PredictionType(get<Integer const>(config["type"]));
  learner->InplacePredict(x, p_m, type, get<Number const>(config["missing"]),
                          &p_predt,
                          get<Integer const>(config["iteration_begin"]),
                          get<Integer const>(config["iteration_end"]));
  CHECK(p_predt);
  CHECK(p_predt->DeviceCanRead());
  CHECK(p_predt->DeviceCanRead() && !p_predt->HostCanRead());

  auto &shape = learner->GetThreadLocal().prediction_shape;
  auto chunksize = x->NumRows() == 0 ? 0 : p_predt->Size() / x->NumRows();
  bool strict_shape = get<Boolean const>(config["strict_shape"]);
  CalcPredictShape(strict_shape, type, x->NumRows(), x->NumColumns(), chunksize,
                   learner->Groups(), learner->BoostedRounds(), &shape,
                   out_dim);
  *out_shape = dmlc::BeginPtr(shape);
  *out_result = p_predt->ConstDevicePointer();
  *out_len = static_cast<xgboost::bst_ulong>(p_predt->Size());

  API_END();
}

// A hidden API as cache id is not being supported yet.
XGB_DLL int XGBoosterPredictFromArrayInterface(BoosterHandle handle,
                                               char const* c_json_strs,
                                               float missing,
                                               unsigned iteration_begin,
                                               unsigned iteration_end,
                                               char const* c_type,
                                               xgboost::bst_ulong cache_id,
                                               xgboost::bst_ulong *out_len,
                                               float const** out_result) {
  API_BEGIN();
  CHECK_HANDLE();
  CHECK_EQ(cache_id, 0) << "Cache ID is not supported yet";
  auto *learner = static_cast<Learner*>(handle);

  std::string json_str{c_json_strs};
  auto x = std::make_shared<data::CupyAdapter>(json_str);
  HostDeviceVector<float>* p_predt { nullptr };
  std::string type { c_type };
  learner->InplacePredict(x, type, missing, &p_predt, iteration_begin, iteration_end);
  CHECK(p_predt);
  CHECK(p_predt->DeviceCanRead());

  *out_result = p_predt->ConstDevicePointer();
  *out_len = static_cast<xgboost::bst_ulong>(p_predt->Size());

  API_END();
XGB_DLL int XGBoosterPredictFromArrayInterfaceColumns(
    BoosterHandle handle, char const *c_json_strs, char const *c_json_config,
    DMatrixHandle m, xgboost::bst_ulong const **out_shape,
    xgboost::bst_ulong *out_dim, const float **out_result) {
  std::shared_ptr<DMatrix> p_m {nullptr};
  if (m) {
    p_m = *static_cast<std::shared_ptr<DMatrix> *>(m);
  }
  return InplacePreidctCuda<data::CudfAdapter>(
      handle, c_json_strs, c_json_config, p_m, out_shape, out_dim, out_result);
}

// A hidden API as cache id is not being supported yet.
XGB_DLL int XGBoosterPredictFromArrayInterface(
    BoosterHandle handle, char const *c_json_strs, char const *c_json_config,
    DMatrixHandle m, xgboost::bst_ulong const **out_shape,
    xgboost::bst_ulong *out_dim, const float **out_result) {
  std::shared_ptr<DMatrix> p_m {nullptr};
  if (m) {
    p_m = *static_cast<std::shared_ptr<DMatrix> *>(m);
  }
  return InplacePreidctCuda<data::CupyAdapter>(
      handle, c_json_strs, c_json_config, p_m, out_shape, out_dim, out_result);
}
src/c_api/c_api_utils.h (new file, 114 lines)
@@ -0,0 +1,114 @@
/*!
 * Copyright (c) 2021 by XGBoost Contributors
 */
#ifndef XGBOOST_C_API_C_API_UTILS_H_
#define XGBOOST_C_API_C_API_UTILS_H_

#include <algorithm>
#include <functional>
#include <vector>

#include "xgboost/logging.h"
#include "xgboost/learner.h"

namespace xgboost {
/* \brief Determine the output shape of prediction.
 *
 * \param strict_shape Whether should we reshape the output with consideration of groups
 *        and forest.
 * \param type Prediction type
 * \param rows Input samples
 * \param cols Input features
 * \param chunksize Total elements of output / rows
 * \param groups Number of output groups from Learner
 * \param rounds end_iteration - beg_iteration
 * \param out_shape Output shape
 * \param out_dim Output dimension
 */
inline void CalcPredictShape(bool strict_shape, PredictionType type, size_t rows, size_t cols,
                             size_t chunksize, size_t groups, size_t rounds,
                             std::vector<bst_ulong> *out_shape,
                             xgboost::bst_ulong *out_dim) {
  auto &shape = *out_shape;
  if ((type == PredictionType::kMargin || type == PredictionType::kValue) &&
      rows != 0) {
    CHECK_EQ(chunksize, groups);
  }

  switch (type) {
  case PredictionType::kValue:
  case PredictionType::kMargin: {
    if (chunksize == 1 && !strict_shape) {
      *out_dim = 1;
      shape.resize(*out_dim);
      shape.front() = rows;
    } else {
      *out_dim = 2;
      shape.resize(*out_dim);
      shape.front() = rows;
      shape.back() = groups;
    }
    break;
  }
  case PredictionType::kApproxContribution:
  case PredictionType::kContribution: {
    auto groups = chunksize / (cols + 1);
    if (groups == 1 && !strict_shape) {
      *out_dim = 2;
      shape.resize(*out_dim);
      shape.front() = rows;
      shape.back() = cols + 1;
    } else {
      *out_dim = 3;
      shape.resize(*out_dim);
      shape[0] = rows;
      shape[1] = groups;
      shape[2] = cols + 1;
    }
    break;
  }
  case PredictionType::kInteraction: {
    if (groups == 1 && !strict_shape) {
      *out_dim = 3;
      shape.resize(*out_dim);
      shape[0] = rows;
      shape[1] = cols + 1;
      shape[2] = cols + 1;
    } else {
      *out_dim = 4;
      shape.resize(*out_dim);
      shape[0] = rows;
      shape[1] = groups;
      shape[2] = cols + 1;
      shape[3] = cols + 1;
    }
    break;
  }
  case PredictionType::kLeaf: {
    if (strict_shape) {
      shape.resize(4);
      shape[0] = rows;
      shape[1] = rounds;
      shape[2] = groups;
      auto forest = chunksize / (shape[1] * shape[2]);
      forest = std::max(static_cast<decltype(forest)>(1), forest);
      shape[3] = forest;
      *out_dim = shape.size();
    } else {
      *out_dim = 2;
      shape.resize(*out_dim);
      shape.front() = rows;
      shape.back() = chunksize;
    }
    break;
  }
  default: {
    LOG(FATAL) << "Unknown prediction type:" << static_cast<int>(type);
  }
  }
  CHECK_EQ(
      std::accumulate(shape.cbegin(), shape.cend(), 1, std::multiplies<>{}),
      chunksize * rows);
}
}  // namespace xgboost
#endif  // XGBOOST_C_API_C_API_UTILS_H_
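The switch above pins down the output dimensionality for every prediction type. A Python-side sketch of the strict-shape branches (parameter names follow the doc comment; `forest` is the per-iteration tree count that the C++ code recovers from `chunksize`):

    def strict_predict_shape(kind, rows, cols, groups, rounds, forest=1):
        # Mirrors the strict_shape == true branches of CalcPredictShape.
        return {
            "value": (rows, groups),
            "margin": (rows, groups),
            "contribution": (rows, groups, cols + 1),
            "approx_contribution": (rows, groups, cols + 1),
            "interaction": (rows, groups, cols + 1, cols + 1),
            "leaf": (rows, rounds, groups, forest),
        }[kind]

    print(strict_predict_shape("contribution", rows=100, cols=10, groups=1))
    # (100, 1, 11)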
@@ -1,5 +1,5 @@
/*!
 * Copyright (c) 2019~2020 by Contributors
 * Copyright (c) 2019~2021 by Contributors
 * \file adapter.h
 */
#ifndef XGBOOST_DATA_ADAPTER_H_
@@ -228,6 +228,128 @@ class DenseAdapter : public detail::SingleBatchDataIter<DenseAdapterBatch> {
  size_t num_columns_;
};

class ArrayAdapterBatch : public detail::NoMetaInfo {
  ArrayInterface array_interface_;

  class Line {
    ArrayInterface array_interface_;
    size_t ridx_;
   public:
    Line(ArrayInterface array_interface, size_t ridx)
        : array_interface_{std::move(array_interface)}, ridx_{ridx} {}

    size_t Size() const { return array_interface_.num_cols; }

    COOTuple GetElement(size_t idx) const {
      return {ridx_, idx, array_interface_.GetElement(idx)};
    }
  };

 public:
  ArrayAdapterBatch() = default;
  Line const GetLine(size_t idx) const {
    auto line = array_interface_.SliceRow(idx);
    return Line{line, idx};
  }

  explicit ArrayAdapterBatch(ArrayInterface array_interface)
      : array_interface_{std::move(array_interface)} {}
};

/**
 * Adapter for dense array on host, in Python that's `numpy.ndarray`.  This is similar to
 * `DenseAdapter`, but supports __array_interface__ instead of raw pointers.  An
 * advantage is this can handle various data type without making a copy.
 */
class ArrayAdapter : public detail::SingleBatchDataIter<ArrayAdapterBatch> {
 public:
  explicit ArrayAdapter(StringView array_interface) {
    auto j = Json::Load(array_interface);
    array_interface_ = ArrayInterface(get<Object const>(j));
    batch_ = ArrayAdapterBatch{array_interface_};
  }
  ArrayAdapterBatch const& Value() const override { return batch_; }
  size_t NumRows() const { return array_interface_.num_rows; }
  size_t NumColumns() const { return array_interface_.num_cols; }

 private:
  ArrayAdapterBatch batch_;
  ArrayInterface array_interface_;
};

class CSRArrayAdapterBatch : public detail::NoMetaInfo {
  ArrayInterface indptr_;
  ArrayInterface indices_;
  ArrayInterface values_;

  class Line {
    ArrayInterface indices_;
    ArrayInterface values_;
    size_t ridx_;

   public:
    Line(ArrayInterface indices, ArrayInterface values, size_t ridx)
        : indices_{std::move(indices)}, values_{std::move(values)}, ridx_{ridx} {}

    COOTuple GetElement(size_t idx) const {
      return {ridx_, indices_.GetElement<size_t>(idx), values_.GetElement(idx)};
    }
    size_t Size() const {
      return values_.num_rows * values_.num_cols;
    }
  };

 public:
  CSRArrayAdapterBatch() = default;
  CSRArrayAdapterBatch(ArrayInterface indptr, ArrayInterface indices,
                       ArrayInterface values)
      : indptr_{std::move(indptr)}, indices_{std::move(indices)},
        values_{std::move(values)} {}

  Line const GetLine(size_t idx) const {
    auto begin_offset = indptr_.GetElement<size_t>(idx);
    auto end_offset = indptr_.GetElement<size_t>(idx + 1);

    auto indices = indices_.SliceOffset(begin_offset);
    auto values = values_.SliceOffset(begin_offset);

    values.num_cols = end_offset - begin_offset;
    values.num_rows = 1;

    indices.num_cols = values.num_cols;
    indices.num_rows = values.num_rows;

    return Line{indices, values, idx};
  }
};

/**
 * Adapter for CSR array on host, in Python that's `scipy.sparse.csr_matrix`.  This is
 * similar to `CSRAdapter`, but supports __array_interface__ instead of raw pointers.  An
 * advantage is this can handle various data type without making a copy.
 */
class CSRArrayAdapter : public detail::SingleBatchDataIter<CSRArrayAdapterBatch> {
 public:
  CSRArrayAdapter(StringView indptr, StringView indices, StringView values,
                  size_t num_cols)
      : indptr_{indptr}, indices_{indices}, values_{values}, num_cols_{num_cols} {
    batch_ = CSRArrayAdapterBatch{indptr_, indices_, values_};
  }

  CSRArrayAdapterBatch const& Value() const override {
    return batch_;
  }
  size_t NumRows() const {
    size_t size = indptr_.num_cols * indptr_.num_rows;
    size = size == 0 ? 0 : size - 1;
    return size;
  }
  size_t NumColumns() const { return num_cols_; }

 private:
  CSRArrayAdapterBatch batch_;
  ArrayInterface indptr_;
  ArrayInterface indices_;
  ArrayInterface values_;
  size_t num_cols_;
};

class CSCAdapterBatch : public detail::NoMetaInfo {
 public:
  CSCAdapterBatch(const size_t* col_ptr, const unsigned* row_idx,
@ -1,5 +1,5 @@
/*!
 * Copyright 2019 by Contributors
 * Copyright 2019-2021 by Contributors
 * \file array_interface.h
 * \brief View of __array_interface__
 */
@ -87,7 +87,7 @@ struct ArrayInterfaceErrors {
    }
  }

  static std::string UnSupportedType(const char (&typestr)[3]) {
  static std::string UnSupportedType(StringView typestr) {
    return TypeStr(typestr[1]) + " is not supported.";
  }
};
@ -210,6 +210,7 @@ class ArrayInterfaceHandler {
            static_cast<bst_feature_t>(get<Integer const>(j_shape.at(1)))};
    }
  }

  template <typename T>
  static common::Span<T> ExtractData(std::map<std::string, Json> const& column) {
    Validate(column);
@ -257,16 +258,24 @@ class ArrayInterface {
    }

    auto typestr = get<String const>(column.at("typestr"));
    type[0] = typestr.at(0);
    type[1] = typestr.at(1);
    type[2] = typestr.at(2);
    this->CheckType();
    this->AssignType(StringView{typestr});
  }

 public:
  enum Type : std::int8_t { kF4, kF8, kI1, kI2, kI4, kI8, kU1, kU2, kU4, kU8 };

 public:
  ArrayInterface() = default;
  explicit ArrayInterface(std::string const& str, bool allow_mask = true) {
    auto jinterface = Json::Load({str.c_str(), str.size()});
  explicit ArrayInterface(std::string const &str, bool allow_mask = true)
      : ArrayInterface{StringView{str.c_str(), str.size()}, allow_mask} {}

  explicit ArrayInterface(std::map<std::string, Json> const &column,
                          bool allow_mask = true) {
    this->Initialize(column, allow_mask);
  }

  explicit ArrayInterface(StringView str, bool allow_mask = true) {
    auto jinterface = Json::Load(str);
    if (IsA<Object>(jinterface)) {
      this->Initialize(get<Object const>(jinterface), allow_mask);
      return;
@ -279,71 +288,114 @@ class ArrayInterface {
    }
  }

  explicit ArrayInterface(std::map<std::string, Json> const &column,
                          bool allow_mask = true) {
    this->Initialize(column, allow_mask);
  }

  void CheckType() const {
    if (type[1] == 'f' && type[2] == '4') {
      return;
    } else if (type[1] == 'f' && type[2] == '8') {
      return;
    } else if (type[1] == 'i' && type[2] == '1') {
      return;
    } else if (type[1] == 'i' && type[2] == '2') {
      return;
    } else if (type[1] == 'i' && type[2] == '4') {
      return;
    } else if (type[1] == 'i' && type[2] == '8') {
      return;
    } else if (type[1] == 'u' && type[2] == '1') {
      return;
    } else if (type[1] == 'u' && type[2] == '2') {
      return;
    } else if (type[1] == 'u' && type[2] == '4') {
      return;
    } else if (type[1] == 'u' && type[2] == '8') {
      return;
  void AssignType(StringView typestr) {
    if (typestr[1] == 'f' && typestr[2] == '4') {
      type = kF4;
    } else if (typestr[1] == 'f' && typestr[2] == '8') {
      type = kF8;
    } else if (typestr[1] == 'i' && typestr[2] == '1') {
      type = kI1;
    } else if (typestr[1] == 'i' && typestr[2] == '2') {
      type = kI2;
    } else if (typestr[1] == 'i' && typestr[2] == '4') {
      type = kI4;
    } else if (typestr[1] == 'i' && typestr[2] == '8') {
      type = kI8;
    } else if (typestr[1] == 'u' && typestr[2] == '1') {
      type = kU1;
    } else if (typestr[1] == 'u' && typestr[2] == '2') {
      type = kU2;
    } else if (typestr[1] == 'u' && typestr[2] == '4') {
      type = kU4;
    } else if (typestr[1] == 'u' && typestr[2] == '8') {
      type = kU8;
    } else {
      LOG(FATAL) << ArrayInterfaceErrors::UnSupportedType(type);
      LOG(FATAL) << ArrayInterfaceErrors::UnSupportedType(typestr);
      return;
    }
  }

  XGBOOST_DEVICE float GetElement(size_t idx) const {
  XGBOOST_DEVICE ArrayInterface SliceOffset(size_t offset) const {
    void* p_values;
    switch (type) {
      case kF4:
        p_values = reinterpret_cast<float *>(data) + offset;
        break;
      case kF8:
        p_values = reinterpret_cast<double *>(data) + offset;
        break;
      case kI1:
        p_values = reinterpret_cast<int8_t *>(data) + offset;
        break;
      case kI2:
        p_values = reinterpret_cast<int16_t *>(data) + offset;
        break;
      case kI4:
        p_values = reinterpret_cast<int32_t *>(data) + offset;
        break;
      case kI8:
        p_values = reinterpret_cast<int64_t *>(data) + offset;
        break;
      case kU1:
        p_values = reinterpret_cast<uint8_t *>(data) + offset;
        break;
      case kU2:
        p_values = reinterpret_cast<uint16_t *>(data) + offset;
        break;
      case kU4:
        p_values = reinterpret_cast<uint32_t *>(data) + offset;
        break;
      case kU8:
        p_values = reinterpret_cast<uint64_t *>(data) + offset;
        break;
    }
    ArrayInterface ret = *this;
    ret.data = p_values;
    return ret;
  }

  XGBOOST_DEVICE ArrayInterface SliceRow(size_t idx) const {
    size_t offset = idx * num_cols;
    auto ret = this->SliceOffset(offset);
    ret.num_rows = 1;
    return ret;
  }

  template <typename T = float>
  XGBOOST_DEVICE T GetElement(size_t idx) const {
    SPAN_CHECK(idx < num_cols * num_rows);
    if (type[1] == 'f' && type[2] == '4') {
    switch (type) {
      case kF4:
        return reinterpret_cast<float*>(data)[idx];
    } else if (type[1] == 'f' && type[2] == '8') {
      case kF8:
        return reinterpret_cast<double*>(data)[idx];
    } else if (type[1] == 'i' && type[2] == '1') {
      case kI1:
        return reinterpret_cast<int8_t*>(data)[idx];
    } else if (type[1] == 'i' && type[2] == '2') {
      case kI2:
        return reinterpret_cast<int16_t*>(data)[idx];
    } else if (type[1] == 'i' && type[2] == '4') {
      case kI4:
        return reinterpret_cast<int32_t*>(data)[idx];
    } else if (type[1] == 'i' && type[2] == '8') {
      case kI8:
        return reinterpret_cast<int64_t*>(data)[idx];
    } else if (type[1] == 'u' && type[2] == '1') {
      case kU1:
        return reinterpret_cast<uint8_t*>(data)[idx];
    } else if (type[1] == 'u' && type[2] == '2') {
      case kU2:
        return reinterpret_cast<uint16_t*>(data)[idx];
    } else if (type[1] == 'u' && type[2] == '4') {
      case kU4:
        return reinterpret_cast<uint32_t*>(data)[idx];
    } else if (type[1] == 'u' && type[2] == '8') {
      case kU8:
        return reinterpret_cast<uint64_t*>(data)[idx];
    } else {
      SPAN_CHECK(false);
      return 0;
    }
    SPAN_CHECK(false);
    return reinterpret_cast<float*>(data)[idx];
  }

  RBitField8 valid;
  bst_row_t num_rows;
  bst_feature_t num_cols;
  void* data;
  char type[3];

  Type type;
};

}  // namespace xgboost
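AssignType decodes the typestr field of NumPy's __array_interface__, a three-character code of byte order, kind, and item size, so "<f4" becomes kF4 and "<u8" becomes kU8. A rough sketch of the same mapping outside the class (hypothetical free function, with an exception in place of LOG(FATAL)):

    #include <cstdint>
    #include <stdexcept>
    #include <string>

    enum class Type : std::int8_t { kF4, kF8, kI1, kI2, kI4, kI8, kU1, kU2, kU4, kU8 };

    // Map a typestr like "<f4" (little endian, float, 4 bytes) to the enum.
    Type TypeFromTypestr(std::string const& t) {
      char kind = t.at(1);
      char size = t.at(2);
      if (kind == 'f' && size == '4') return Type::kF4;
      if (kind == 'f' && size == '8') return Type::kF8;
      if (kind == 'i' && size == '1') return Type::kI1;
      if (kind == 'i' && size == '2') return Type::kI2;
      if (kind == 'i' && size == '4') return Type::kI4;
      if (kind == 'i' && size == '8') return Type::kI8;
      if (kind == 'u' && size == '1') return Type::kU1;
      if (kind == 'u' && size == '2') return Type::kU2;
      if (kind == 'u' && size == '4') return Type::kU4;
      if (kind == 'u' && size == '8') return Type::kU8;
      throw std::invalid_argument(t + " is not supported.");
    }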
@ -1,5 +1,5 @@
/*!
 * Copyright 2019 by XGBoost Contributors
 * Copyright 2019-2021 by XGBoost Contributors
 *
 * \file data.cu
 * \brief Handles setting metainfo from array interface.
@ -45,15 +45,15 @@ auto SetDeviceToPtr(void *ptr) {
}  // anonymous namespace

void CopyGroupInfoImpl(ArrayInterface column, std::vector<bst_group_t>* out) {
  CHECK(column.type[1] == 'i' || column.type[1] == 'u')
      << "Expected integer metainfo";
  CHECK(column.type != ArrayInterface::kF4 && column.type != ArrayInterface::kF8)
      << "Expected integer for group info.";

  auto ptr_device = SetDeviceToPtr(column.data);
  dh::TemporaryArray<bst_group_t> temp(column.num_rows);
  auto d_tmp = temp.data();

  dh::LaunchN(ptr_device, column.num_rows, [=] __device__(size_t idx) {
    d_tmp[idx] = column.GetElement(idx);
    d_tmp[idx] = column.GetElement<size_t>(idx);
  });
  auto length = column.num_rows;
  out->resize(length + 1);
@ -103,15 +103,15 @@ void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
    auto it = dh::MakeTransformIterator<uint32_t>(
        thrust::make_counting_iterator(0ul),
        [array_interface] __device__(size_t i) {
          return static_cast<uint32_t>(array_interface.GetElement(i));
          return array_interface.GetElement<uint32_t>(i);
        });
    dh::caching_device_vector<bool> flag(1);
    auto d_flag = dh::ToSpan(flag);
    auto d = SetDeviceToPtr(array_interface.data);
    dh::LaunchN(d, 1, [=] __device__(size_t) { d_flag[0] = true; });
    dh::LaunchN(d, array_interface.num_rows - 1, [=] __device__(size_t i) {
      if (static_cast<uint32_t>(array_interface.GetElement(i)) >
          static_cast<uint32_t>(array_interface.GetElement(i + 1))) {
      if (array_interface.GetElement<uint32_t>(i) >
          array_interface.GetElement<uint32_t>(i + 1)) {
        d_flag[0] = false;
      }
    });
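The second hunk's kernel only verifies that the group column is non-decreasing, flipping a device-side flag when two neighbours are out of order. The equivalent host-side check is a one-liner, sketched here for contrast (hypothetical helper):

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    // True when each group id is <= its successor, i.e. groups arrive sorted.
    bool GroupsAreSorted(std::vector<std::uint32_t> const& group_ids) {
      return std::is_sorted(group_ids.begin(), group_ids.end());
    }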
@ -1,5 +1,5 @@
/*!
 * Copyright 2014-2020 by Contributors
 * Copyright 2014-2021 by Contributors
 * \file gbtree.cc
 * \brief gradient boosted tree implementation.
 * \author Tianqi Chen
@ -265,15 +265,34 @@ class GBTree : public GradientBooster {
               bool training,
               unsigned ntree_limit) override;

  void InplacePredict(dmlc::any const &x, float missing,
                      PredictionCacheEntry *out_preds,
                      uint32_t layer_begin,
                      unsigned layer_end) const override {
  void InplacePredict(dmlc::any const &x, std::shared_ptr<DMatrix> p_m,
                      float missing, PredictionCacheEntry *out_preds,
                      uint32_t layer_begin, unsigned layer_end) const override {
    CHECK(configured_);
    uint32_t tree_begin, tree_end;
    std::tie(tree_begin, tree_end) = detail::LayerToTree(model_, tparam_, layer_begin, layer_end);
    this->GetPredictor()->InplacePredict(x, model_, missing, out_preds,
                                         tree_begin, tree_end);
    std::tie(tree_begin, tree_end) =
        detail::LayerToTree(model_, tparam_, layer_begin, layer_end);
    std::vector<Predictor const *> predictors{
        cpu_predictor_.get(),
#if defined(XGBOOST_USE_CUDA)
        gpu_predictor_.get()
#endif  // defined(XGBOOST_USE_CUDA)
    };
    StringView msg{"Unsupported data type for inplace predict."};
    if (tparam_.predictor == PredictorType::kAuto) {
      // Try both predictor implementations
      for (auto const &p : predictors) {
        if (p && p->InplacePredict(x, p_m, model_, missing, out_preds,
                                   tree_begin, tree_end)) {
          return;
        }
      }
      LOG(FATAL) << msg;
    } else {
      bool success = this->GetPredictor()->InplacePredict(
          x, p_m, model_, missing, out_preds, tree_begin, tree_end);
      CHECK(success) << msg;
    }
  }

  void PredictInstance(const SparsePage::Inst& inst,
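With the strict bool return added by this commit, the kAuto path simply walks the available predictors and takes the first one that accepts the input, failing loudly only when none does. A compact sketch of that first-match-wins pattern (hypothetical TryEach, using an exception where GBTree uses LOG(FATAL)):

    #include <functional>
    #include <stdexcept>
    #include <vector>

    // Run handlers in order and stop at the first that reports success;
    // fail loudly when none of them can handle the input.
    void TryEach(std::vector<std::function<bool()>> const& handlers) {
      for (auto const& handler : handlers) {
        if (handler && handler()) {
          return;
        }
      }
      throw std::runtime_error("Unsupported data type for inplace predict.");
    }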
@ -1,5 +1,5 @@
/*!
 * Copyright 2014-2020 by Contributors
 * Copyright 2014-2021 by Contributors
 * \file learner.cc
 * \brief Implementation of learning algorithm.
 * \author Tianqi Chen
@ -1110,23 +1110,30 @@ class LearnerImpl : public LearnerIO {
    CHECK(!this->need_configuration_);
    return this->gbm_->BoostedRounds();
  }
  uint32_t Groups() const override {
    CHECK(!this->need_configuration_);
    return this->learner_model_param_.num_output_group;
  }

  XGBAPIThreadLocalEntry& GetThreadLocal() const override {
    return (*LearnerAPIThreadLocalStore::Get())[this];
  }

  void InplacePredict(dmlc::any const &x, std::string const &type,
                      float missing, HostDeviceVector<bst_float> **out_preds,
                      uint32_t layer_begin, uint32_t layer_end) override {
  void InplacePredict(dmlc::any const &x, std::shared_ptr<DMatrix> p_m,
                      PredictionType type, float missing,
                      HostDeviceVector<bst_float> **out_preds,
                      uint32_t iteration_begin,
                      uint32_t iteration_end) override {
    this->Configure();
    auto& out_predictions = this->GetThreadLocal().prediction_entry;
    this->gbm_->InplacePredict(x, missing, &out_predictions, layer_begin,
                               layer_end);
    if (type == "value") {
    this->gbm_->InplacePredict(x, p_m, missing, &out_predictions,
                               iteration_begin, iteration_end);
    if (type == PredictionType::kValue) {
      obj_->PredTransform(&out_predictions.predictions);
    } else if (type == "margin") {
    } else if (type == PredictionType::kMargin) {
      // do nothing
    } else {
      LOG(FATAL) << "Unsupported prediction type: " << type;
      LOG(FATAL) << "Unsupported prediction type: " << static_cast<int>(type);
    }
    *out_preds = &out_predictions.predictions;
  }
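Swapping the stringly-typed "value"/"margin" argument for PredictionType pushes validation to the API boundary: a caller translates the user-facing string once, and the rest of the path works with the enum. A sketch of such a boundary mapping (hypothetical helper, restricted to the two types inplace predict accepts):

    #include <cstdint>
    #include <stdexcept>
    #include <string>

    enum class PredictionType : std::uint8_t { kValue = 0, kMargin = 1 };

    // Translate a user-facing string once, so downstream code works with the
    // strongly typed enum instead of repeated string comparisons.
    PredictionType ParsePredictionType(std::string const& type) {
      if (type == "value")  return PredictionType::kValue;
      if (type == "margin") return PredictionType::kMargin;
      throw std::invalid_argument("Unsupported prediction type: " + type);
    }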
@ -1,5 +1,5 @@
/*!
 * Copyright by Contributors 2017-2020
 * Copyright by Contributors 2017-2021
 */
#include <dmlc/omp.h>
#include <dmlc/any.h>
@ -287,7 +287,7 @@ class CPUPredictor : public Predictor {
  }

  template <typename Adapter>
  void DispatchedInplacePredict(dmlc::any const &x,
  void DispatchedInplacePredict(dmlc::any const &x, std::shared_ptr<DMatrix> p_m,
                                const gbm::GBTreeModel &model, float missing,
                                PredictionCacheEntry *out_preds,
                                uint32_t tree_begin, uint32_t tree_end) const {
@ -295,33 +295,44 @@ class CPUPredictor : public Predictor {
    auto m = dmlc::get<std::shared_ptr<Adapter>>(x);
    CHECK_EQ(m->NumColumns(), model.learner_model_param->num_feature)
        << "Number of columns in data must equal the number of features in the model.";
    if (p_m) {
      p_m->Info().num_row_ = m->NumRows();
      this->InitOutPredictions(p_m->Info(), &(out_preds->predictions), model);
    } else {
      MetaInfo info;
      info.num_col_ = m->NumColumns();
      info.num_row_ = m->NumRows();
      this->InitOutPredictions(info, &(out_preds->predictions), model);
      std::vector<Entry> workspace(info.num_col_ * 8 * threads);
    }
    std::vector<Entry> workspace(m->NumColumns() * 8 * threads);
    auto &predictions = out_preds->predictions.HostVector();
    std::vector<RegTree::FVec> thread_temp;
    InitThreadTemp(threads*kBlockOfRowsSize, model.learner_model_param->num_feature,
                   &thread_temp);
    PredictBatchByBlockOfRowsKernel<AdapterView<Adapter>,
                                    kBlockOfRowsSize>(AdapterView<Adapter>(
        m.get(), missing, common::Span<Entry>{workspace}),
    InitThreadTemp(threads * kBlockOfRowsSize,
                   model.learner_model_param->num_feature, &thread_temp);
    PredictBatchByBlockOfRowsKernel<AdapterView<Adapter>, kBlockOfRowsSize>(
        AdapterView<Adapter>(m.get(), missing, common::Span<Entry>{workspace}),
        &predictions, model, tree_begin, tree_end, &thread_temp);
  }

  void InplacePredict(dmlc::any const &x, const gbm::GBTreeModel &model,
                      float missing, PredictionCacheEntry *out_preds,
                      uint32_t tree_begin, unsigned tree_end) const override {
  bool InplacePredict(dmlc::any const &x, std::shared_ptr<DMatrix> p_m,
                      const gbm::GBTreeModel &model, float missing,
                      PredictionCacheEntry *out_preds, uint32_t tree_begin,
                      unsigned tree_end) const override {
    if (x.type() == typeid(std::shared_ptr<data::DenseAdapter>)) {
      this->DispatchedInplacePredict<data::DenseAdapter>(
          x, model, missing, out_preds, tree_begin, tree_end);
          x, p_m, model, missing, out_preds, tree_begin, tree_end);
    } else if (x.type() == typeid(std::shared_ptr<data::CSRAdapter>)) {
      this->DispatchedInplacePredict<data::CSRAdapter>(
          x, model, missing, out_preds, tree_begin, tree_end);
          x, p_m, model, missing, out_preds, tree_begin, tree_end);
    } else if (x.type() == typeid(std::shared_ptr<data::ArrayAdapter>)) {
      this->DispatchedInplacePredict<data::ArrayAdapter>(
          x, p_m, model, missing, out_preds, tree_begin, tree_end);
    } else if (x.type() == typeid(std::shared_ptr<data::CSRArrayAdapter>)) {
      this->DispatchedInplacePredict<data::CSRArrayAdapter>(
          x, p_m, model, missing, out_preds, tree_begin, tree_end);
    } else {
      LOG(FATAL) << "Data type is not supported by CPU Predictor.";
      return false;
    }
    return true;
  }

  void PredictInstance(const SparsePage::Inst& inst,
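The dispatch on x.type() is plain type erasure: the dmlc::any payload is probed against each known adapter type and unwrapped on a match, and the new return false lets the caller try another predictor instead of aborting. A compact sketch of the pattern, with std::any standing in for dmlc::any and hypothetical adapter types:

    #include <any>
    #include <iostream>
    #include <memory>

    struct DenseAdapter { };
    struct CSRAdapter { };

    // Probe the erased value against each supported adapter type; return
    // false (instead of aborting) when nothing matches, so a caller can
    // fall back to another implementation.
    bool Dispatch(std::any const& x) {
      if (std::any_cast<std::shared_ptr<DenseAdapter>>(&x)) {
        std::cout << "dense path\n";
      } else if (std::any_cast<std::shared_ptr<CSRAdapter>>(&x)) {
        std::cout << "csr path\n";
      } else {
        return false;
      }
      return true;
    }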
@ -1,5 +1,5 @@
/*!
 * Copyright 2017-2020 by Contributors
 * Copyright 2017-2021 by Contributors
 */
#include <thrust/copy.h>
#include <thrust/device_ptr.h>
@ -644,7 +644,7 @@ class GPUPredictor : public xgboost::Predictor {
  }

  template <typename Adapter, typename Loader>
  void DispatchedInplacePredict(dmlc::any const &x,
  void DispatchedInplacePredict(dmlc::any const &x, std::shared_ptr<DMatrix> p_m,
                                const gbm::GBTreeModel &model, float,
                                PredictionCacheEntry *out_preds,
                                uint32_t tree_begin, uint32_t tree_end) const {
@ -659,16 +659,20 @@ class GPUPredictor : public xgboost::Predictor {
    CHECK_EQ(this->generic_param_->gpu_id, m->DeviceIdx())
        << "XGBoost is running on device: " << this->generic_param_->gpu_id << ", "
        << "but data is on: " << m->DeviceIdx();
    if (p_m) {
      p_m->Info().num_row_ = m->NumRows();
      this->InitOutPredictions(p_m->Info(), &(out_preds->predictions), model);
    } else {
      MetaInfo info;
      info.num_col_ = m->NumColumns();
      info.num_row_ = m->NumRows();
      this->InitOutPredictions(info, &(out_preds->predictions), model);
    }

    const uint32_t BLOCK_THREADS = 128;
    auto GRID_SIZE = static_cast<uint32_t>(common::DivRoundUp(info.num_row_, BLOCK_THREADS));
    auto GRID_SIZE = static_cast<uint32_t>(common::DivRoundUp(m->NumRows(), BLOCK_THREADS));

    size_t shared_memory_bytes =
        SharedMemoryBytes<BLOCK_THREADS>(info.num_col_, max_shared_memory_bytes);
        SharedMemoryBytes<BLOCK_THREADS>(m->NumColumns(), max_shared_memory_bytes);
    bool use_shared = shared_memory_bytes != 0;
    size_t entry_start = 0;

@ -680,23 +684,25 @@ class GPUPredictor : public xgboost::Predictor {
        d_model.categories_tree_segments.ConstDeviceSpan(),
        d_model.categories_node_segments.ConstDeviceSpan(),
        d_model.categories.ConstDeviceSpan(), tree_begin, tree_end, m->NumColumns(),
        info.num_row_, entry_start, use_shared, output_groups);
        m->NumRows(), entry_start, use_shared, output_groups);
  }

  void InplacePredict(dmlc::any const &x, const gbm::GBTreeModel &model,
                      float missing, PredictionCacheEntry *out_preds,
                      uint32_t tree_begin, unsigned tree_end) const override {
  bool InplacePredict(dmlc::any const &x, std::shared_ptr<DMatrix> p_m,
                      const gbm::GBTreeModel &model, float missing,
                      PredictionCacheEntry *out_preds, uint32_t tree_begin,
                      unsigned tree_end) const override {
    if (x.type() == typeid(std::shared_ptr<data::CupyAdapter>)) {
      this->DispatchedInplacePredict<
          data::CupyAdapter, DeviceAdapterLoader<data::CupyAdapterBatch>>(
          x, model, missing, out_preds, tree_begin, tree_end);
          x, p_m, model, missing, out_preds, tree_begin, tree_end);
    } else if (x.type() == typeid(std::shared_ptr<data::CudfAdapter>)) {
      this->DispatchedInplacePredict<
          data::CudfAdapter, DeviceAdapterLoader<data::CudfAdapterBatch>>(
          x, model, missing, out_preds, tree_begin, tree_end);
          x, p_m, model, missing, out_preds, tree_begin, tree_end);
    } else {
      LOG(FATAL) << "Only CuPy and CuDF are supported by GPU Predictor.";
      return false;
    }
    return true;
  }

  void PredictContribution(DMatrix* p_fmat,
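The launch geometry divides the row count by the block size and rounds up, so a partial tail of rows still gets a block. The arithmetic behind common::DivRoundUp, sketched standalone:

    #include <cstdint>

    // Ceiling division: number of block-sized chunks needed to cover n rows.
    constexpr std::uint32_t DivRoundUp(std::uint64_t n, std::uint32_t block) {
      return static_cast<std::uint32_t>((n + block - 1) / block);
    }

    static_assert(DivRoundUp(1000, 128) == 8, "7 full blocks plus one partial");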
@ -1,5 +1,5 @@
/*!
 * Copyright 2020 by Contributors
 * Copyright 2020-2021 by Contributors
 */

#include <gtest/gtest.h>
@ -104,21 +104,24 @@ void TestInplacePrediction(dmlc::any x, std::string predictor,
  }

  HostDeviceVector<float> *p_out_predictions_0{nullptr};
  learner->InplacePredict(x, "margin", std::numeric_limits<float>::quiet_NaN(),
  learner->InplacePredict(x, nullptr, PredictionType::kMargin,
                          std::numeric_limits<float>::quiet_NaN(),
                          &p_out_predictions_0, 0, 2);
  CHECK(p_out_predictions_0);
  HostDeviceVector<float> predict_0(p_out_predictions_0->Size());
  predict_0.Copy(*p_out_predictions_0);

  HostDeviceVector<float> *p_out_predictions_1{nullptr};
  learner->InplacePredict(x, "margin", std::numeric_limits<float>::quiet_NaN(),
  learner->InplacePredict(x, nullptr, PredictionType::kMargin,
                          std::numeric_limits<float>::quiet_NaN(),
                          &p_out_predictions_1, 2, 4);
  CHECK(p_out_predictions_1);
  HostDeviceVector<float> predict_1(p_out_predictions_1->Size());
  predict_1.Copy(*p_out_predictions_1);

  HostDeviceVector<float>* p_out_predictions{nullptr};
  learner->InplacePredict(x, "margin", std::numeric_limits<float>::quiet_NaN(),
  learner->InplacePredict(x, nullptr, PredictionType::kMargin,
                          std::numeric_limits<float>::quiet_NaN(),
                          &p_out_predictions, 0, 4);

  auto& h_pred = p_out_predictions->HostVector();
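The three calls split the booster into layers [0, 2), [2, 4), and [0, 4); in margin space the layered outputs should sum to the full-model output, which appears to be the property the elided tail of this test asserts. An illustrative check of that additivity, assuming the global base score is included in each margin call and hence must be discounted once (hypothetical helper):

    #include <cassert>
    #include <cmath>
    #include <cstddef>
    #include <vector>

    // Tree layers contribute additively in margin space; the base score
    // appears in every call, so it is subtracted once when summing two
    // layered predictions.
    void CheckLayerAdditivity(std::vector<float> const& first,
                              std::vector<float> const& second,
                              std::vector<float> const& full,
                              float base_score) {
      for (std::size_t i = 0; i < full.size(); ++i) {
        assert(std::fabs(first[i] + second[i] - base_score - full[i]) < 1e-5f);
      }
    }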
@ -11,8 +11,7 @@ import testing as tm
class TestDeviceQuantileDMatrix:
    def test_dmatrix_numpy_init(self):
        data = np.random.randn(5, 5)
        with pytest.raises(TypeError,
                           match='is not supported for DeviceQuantileDMatrix'):
        with pytest.raises(TypeError, match='is not supported'):
            xgb.DeviceQuantileDMatrix(data, np.ones(5, dtype=np.float64))

    @pytest.mark.skipif(**tm.no_cupy())
@ -141,6 +141,13 @@ class TestGPUPredict:
        assert np.allclose(cpu_train_score, gpu_train_score)
        assert np.allclose(cpu_test_score, gpu_test_score)

    def run_inplace_base_margin(self, booster, dtrain, X, base_margin):
        import cupy as cp
        dtrain.set_info(base_margin=base_margin)
        from_inplace = booster.inplace_predict(data=X, base_margin=base_margin)
        from_dmatrix = booster.predict(dtrain)
        cp.testing.assert_allclose(from_inplace, from_dmatrix)

    @pytest.mark.skipif(**tm.no_cupy())
    def test_inplace_predict_cupy(self):
        import cupy as cp
@ -175,6 +182,9 @@ class TestGPUPredict:
        for i in range(10):
            run_threaded_predict(X, rows, predict_dense)

        base_margin = cp_rng.randn(rows)
        self.run_inplace_base_margin(booster, dtrain, X, base_margin)

    @pytest.mark.skipif(**tm.no_cudf())
    def test_inplace_predict_cudf(self):
        import cupy as cp
@ -208,6 +218,9 @@ class TestGPUPredict:
        for i in range(10):
            run_threaded_predict(X, rows, predict_df)

        base_margin = cudf.Series(rng.randn(rows))
        self.run_inplace_base_margin(booster, dtrain, X, base_margin)

    @given(strategies.integers(1, 10),
           tm.dataset_strategy, shap_parameter_strategy)
    @settings(deadline=None)
@ -80,20 +80,28 @@ def test_predict_leaf():

class TestInplacePredict:
    '''Tests for running inplace prediction'''
    def test_predict(self):
        rows = 1000
        cols = 10
    @classmethod
    def setup_class(cls):
        cls.rows = 100
        cls.cols = 10

        np.random.seed(1994)
        cls.rng = np.random.RandomState(1994)

        X = np.random.randn(rows, cols)
        y = np.random.randn(rows)
        dtrain = xgb.DMatrix(X, y)
        cls.X = cls.rng.randn(cls.rows, cls.cols)
        cls.y = cls.rng.randn(cls.rows)

        booster = xgb.train({'tree_method': 'hist'},
        dtrain = xgb.DMatrix(cls.X, cls.y)

        cls.booster = xgb.train({'tree_method': 'hist'},
                                dtrain, num_boost_round=10)

        test = xgb.DMatrix(X[:10, ...])
        cls.test = xgb.DMatrix(cls.X[:10, ...])

    def test_predict(self):
        booster = self.booster
        X = self.X
        test = self.test

        predt_from_array = booster.inplace_predict(X[:10, ...])
        predt_from_dmatrix = booster.predict(test)

@ -111,7 +119,7 @@ class TestInplacePredict:
            return np.all(copied_predt == inplace_predt)

        for i in range(10):
            run_threaded_predict(X, rows, predict_dense)
            run_threaded_predict(X, self.rows, predict_dense)

        def predict_csr(x):
            inplace_predt = booster.inplace_predict(sparse.csr_matrix(x))
@ -120,4 +128,14 @@ class TestInplacePredict:
            return np.all(copied_predt == inplace_predt)

        for i in range(10):
            run_threaded_predict(X, rows, predict_csr)
            run_threaded_predict(X, self.rows, predict_csr)

    def test_base_margin(self):
        booster = self.booster

        base_margin = self.rng.randn(self.rows)
        from_inplace = booster.inplace_predict(data=self.X, base_margin=base_margin)

        dtrain = xgb.DMatrix(self.X, self.y, base_margin=base_margin)
        from_dmatrix = booster.predict(dtrain)
        np.testing.assert_allclose(from_dmatrix, from_inplace)
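base_margin enters the model additively before any link function is applied, which is why the inplace and DMatrix paths in test_base_margin must agree. For a logistic objective, say, the final score works out as below (illustrative arithmetic only, not the library's implementation):

    #include <cmath>

    // Final probability for, e.g., binary:logistic: the supplied base margin
    // is added to the accumulated tree outputs before the sigmoid link.
    float PredictWithBaseMargin(float base_margin, float sum_of_tree_outputs) {
      float margin = base_margin + sum_of_tree_outputs;
      return 1.0f / (1.0f + std::exp(-margin));
    }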
@ -1,6 +1,5 @@
import numpy as np
import xgboost as xgb
import testing as tm
import pytest

try: