diff --git a/include/xgboost/gbm.h b/include/xgboost/gbm.h index 9de6e04f2..a49377521 100644 --- a/include/xgboost/gbm.h +++ b/include/xgboost/gbm.h @@ -1,5 +1,5 @@ /*! - * Copyright 2014-2020 by Contributors + * Copyright 2014-2021 by Contributors * \file gbm.h * \brief Interface of gradient booster, * that learns through gradient statistics. @@ -118,7 +118,7 @@ class GradientBooster : public Model, public Configurable { * \param layer_begin (Optional) Begining of boosted tree layer used for prediction. * \param layer_end (Optional) End of booster layer. 0 means do not limit trees. */ - virtual void InplacePredict(dmlc::any const &, float, + virtual void InplacePredict(dmlc::any const &, std::shared_ptr, float, PredictionCacheEntry*, uint32_t, uint32_t) const { diff --git a/include/xgboost/json.h b/include/xgboost/json.h index a10941942..db464e052 100644 --- a/include/xgboost/json.h +++ b/include/xgboost/json.h @@ -308,6 +308,7 @@ struct StringView { public: StringView() = default; StringView(CharT const* str, size_t size) : str_{str}, size_{size} {} + explicit StringView(std::string const& str): str_{str.c_str()}, size_{str.size()} {} explicit StringView(CharT const* str) : str_{str}, size_{Traits::length(str)} {} CharT const& operator[](size_t p) const { return str_[p]; } diff --git a/include/xgboost/learner.h b/include/xgboost/learner.h index b98c4fdb9..d2bd51080 100644 --- a/include/xgboost/learner.h +++ b/include/xgboost/learner.h @@ -1,5 +1,5 @@ /*! - * Copyright 2015-2020 by Contributors + * Copyright 2015-2021 by Contributors * \file learner.h * \brief Learner interface that integrates objective, gbm and evaluation together. * This is the user facing XGBoost training module. @@ -30,6 +30,15 @@ class ObjFunction; class DMatrix; class Json; +enum class PredictionType : std::uint8_t { // NOLINT + kValue = 0, + kMargin = 1, + kContribution = 2, + kApproxContribution = 3, + kInteraction = 4, + kLeaf = 5 +}; + /*! \brief entry to to easily hold returning information */ struct XGBAPIThreadLocalEntry { /*! \brief result holder for returning string */ @@ -42,7 +51,10 @@ struct XGBAPIThreadLocalEntry { std::vector ret_vec_float; /*! \brief temp variable of gradient pairs. */ std::vector tmp_gpair; + /*! \brief Temp variable for returing prediction result. */ PredictionCacheEntry prediction_entry; + /*! \brief Temp variable for returing prediction shape. */ + std::vector prediction_shape; }; /*! @@ -123,13 +135,17 @@ class Learner : public Model, public Configurable, public dmlc::Serializable { * \brief Inplace prediction. * * \param x A type erased data adapter. + * \param p_m An optional Proxy DMatrix object storing meta info like + * base margin. Can be nullptr. * \param type Prediction type. * \param missing Missing value in the data. * \param [in,out] out_preds Pointer to output prediction vector. - * \param layer_begin (Optional) Begining of boosted tree layer used for prediction. - * \param layer_end (Optional) End of booster layer. 0 means do not limit trees. + * \param layer_begin Begining of boosted tree layer used for prediction. + * \param layer_end End of booster layer. 0 means do not limit trees. */ - virtual void InplacePredict(dmlc::any const& x, std::string const& type, + virtual void InplacePredict(dmlc::any const &x, + std::shared_ptr p_m, + PredictionType type, float missing, HostDeviceVector **out_preds, uint32_t layer_begin, uint32_t layer_end) = 0; @@ -138,6 +154,7 @@ class Learner : public Model, public Configurable, public dmlc::Serializable { * \brief Get number of boosted rounds from gradient booster. */ virtual int32_t BoostedRounds() const = 0; + virtual uint32_t Groups() const = 0; void LoadModel(Json const& in) override = 0; void SaveModel(Json* out) const override = 0; diff --git a/include/xgboost/predictor.h b/include/xgboost/predictor.h index 5ab734359..442cf91dc 100644 --- a/include/xgboost/predictor.h +++ b/include/xgboost/predictor.h @@ -1,5 +1,5 @@ /*! - * Copyright 2017-2020 by Contributors + * Copyright 2017-2021 by Contributors * \file predictor.h * \brief Interface of predictor, * performs predictions for a gradient booster. @@ -142,10 +142,14 @@ class Predictor { * \param [in,out] out_preds The output preds. * \param tree_begin (Optional) Begining of boosted trees used for prediction. * \param tree_end (Optional) End of booster trees. 0 means do not limit trees. + * + * \return True if the data can be handled by current predictor, false otherwise. */ - virtual void InplacePredict(dmlc::any const &x, const gbm::GBTreeModel &model, - float missing, PredictionCacheEntry *out_preds, - uint32_t tree_begin = 0, uint32_t tree_end = 0) const = 0; + virtual bool InplacePredict(dmlc::any const &x, std::shared_ptr p_m, + const gbm::GBTreeModel &model, float missing, + PredictionCacheEntry *out_preds, + uint32_t tree_begin = 0, + uint32_t tree_end = 0) const = 0; /** * \brief online prediction function, predict score for one instance at a time * NOTE: use the batch prediction interface if possible, batch prediction is diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index b37a908a9..765fea4d7 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -58,21 +58,23 @@ CallbackEnv = collections.namedtuple( "evaluation_result_list"]) -def from_pystr_to_cstr(data): - """Convert a list of Python str to C pointer +def from_pystr_to_cstr(data: Union[str, List[str]]): + """Convert a Python str or list of Python str to C pointer Parameters ---------- - data : list - list of str + data + str or list of str """ - if not isinstance(data, list): - raise NotImplementedError - pointers = (ctypes.c_char_p * len(data))() - data = [bytes(d, 'utf-8') for d in data] - pointers[:] = data - return pointers + if isinstance(data, str): + return bytes(data, "utf-8") + if isinstance(data, list): + pointers = (ctypes.c_char_p * len(data))() + data = [bytes(d, 'utf-8') for d in data] + pointers[:] = data + return pointers + raise TypeError() def from_cstr_to_pystr(data, length): @@ -190,21 +192,40 @@ def _check_call(ret): raise XGBoostError(py_str(_LIB.XGBGetLastError())) -def ctypes2numpy(cptr, length, dtype) -> np.ndarray: - """Convert a ctypes pointer array to a numpy array.""" - NUMPY_TO_CTYPES_MAPPING = { +def _numpy2ctypes_type(dtype): + _NUMPY_TO_CTYPES_MAPPING = { np.float32: ctypes.c_float, + np.float64: ctypes.c_double, np.uint32: ctypes.c_uint, + np.uint64: ctypes.c_uint64, + np.int32: ctypes.c_int32, + np.int64: ctypes.c_int64, } - if dtype not in NUMPY_TO_CTYPES_MAPPING: - raise RuntimeError('Supported types: {}'.format( - NUMPY_TO_CTYPES_MAPPING.keys())) - ctype = NUMPY_TO_CTYPES_MAPPING[dtype] + if np.intc is not np.int32: # Windows + _NUMPY_TO_CTYPES_MAPPING[np.intc] = _NUMPY_TO_CTYPES_MAPPING[np.int32] + if dtype not in _NUMPY_TO_CTYPES_MAPPING.keys(): + raise TypeError( + f"Supported types: {_NUMPY_TO_CTYPES_MAPPING.keys()}, got: {dtype}" + ) + return _NUMPY_TO_CTYPES_MAPPING[dtype] + + +def _array_interface(data: np.ndarray) -> bytes: + interface = data.__array_interface__ + if "mask" in interface: + interface["mask"] = interface["mask"].__array_interface__ + interface_str = bytes(json.dumps(interface, indent=2), "utf-8") + return interface_str + + +def ctypes2numpy(cptr, length, dtype): + """Convert a ctypes pointer array to a numpy array.""" + ctype = _numpy2ctypes_type(dtype) if not isinstance(cptr, ctypes.POINTER(ctype)): - raise RuntimeError('expected {} pointer'.format(ctype)) + raise RuntimeError("expected {} pointer".format(ctype)) res = np.zeros(length, dtype=dtype) if not ctypes.memmove(res.ctypes.data, cptr, length * res.strides[0]): - raise RuntimeError('memmove failed') + raise RuntimeError("memmove failed") return res @@ -214,25 +235,21 @@ def ctypes2cupy(cptr, length, dtype): import cupy from cupy.cuda.memory import MemoryPointer from cupy.cuda.memory import UnownedMemory - CUPY_TO_CTYPES_MAPPING = { - cupy.float32: ctypes.c_float, - cupy.uint32: ctypes.c_uint - } + + CUPY_TO_CTYPES_MAPPING = {cupy.float32: ctypes.c_float, cupy.uint32: ctypes.c_uint} if dtype not in CUPY_TO_CTYPES_MAPPING.keys(): - raise RuntimeError('Supported types: {}'.format( - CUPY_TO_CTYPES_MAPPING.keys() - )) + raise RuntimeError("Supported types: {}".format(CUPY_TO_CTYPES_MAPPING.keys())) addr = ctypes.cast(cptr, ctypes.c_void_p).value # pylint: disable=c-extension-no-member,no-member device = cupy.cuda.runtime.pointerGetAttributes(addr).device # The owner field is just used to keep the memory alive with ref count. As # unowned's life time is scoped within this function we don't need that. unownd = UnownedMemory( - addr, length.value * ctypes.sizeof(CUPY_TO_CTYPES_MAPPING[dtype]), - owner=None) + addr, length * ctypes.sizeof(CUPY_TO_CTYPES_MAPPING[dtype]), owner=None + ) memptr = MemoryPointer(unownd, 0) # pylint: disable=unexpected-keyword-arg - mem = cupy.ndarray((length.value, ), dtype=dtype, memptr=memptr) + mem = cupy.ndarray((length,), dtype=dtype, memptr=memptr) assert mem.device.id == device arr = cupy.array(mem, copy=True) return arr @@ -256,28 +273,29 @@ def c_str(string): def c_array(ctype, values): """Convert a python string to c array.""" - if (isinstance(values, np.ndarray) - and values.dtype.itemsize == ctypes.sizeof(ctype)): + if isinstance(values, np.ndarray) and values.dtype.itemsize == ctypes.sizeof(ctype): return (ctype * len(values)).from_buffer_copy(values) return (ctype * len(values))(*values) +def _prediction_output(shape, dims, predts, is_cuda): + arr_shape: np.ndarray = ctypes2numpy(shape, dims.value, np.uint64) + length = int(np.prod(arr_shape)) + if is_cuda: + arr_predict = ctypes2cupy(predts, length, np.float32) + else: + arr_predict: np.ndarray = ctypes2numpy(predts, length, np.float32) + arr_predict = arr_predict.reshape(arr_shape) + return arr_predict + + class DataIter: - '''The interface for user defined data iterator. Currently is only - supported by Device DMatrix. + '''The interface for user defined data iterator. Currently is only supported by Device + DMatrix. - Parameters - ---------- - - rows : int - Total number of rows combining all batches. - cols : int - Number of columns for each batch. ''' def __init__(self): - proxy_handle = ctypes.c_void_p() - _check_call(_LIB.XGProxyDMatrixCreate(ctypes.byref(proxy_handle))) - self._handle = DeviceQuantileDMatrix(proxy_handle) + self._handle = _ProxyDMatrix() self.exception = None @property @@ -300,12 +318,7 @@ class DataIter: if self.exception is not None: return 0 - def data_handle(data, label=None, weight=None, base_margin=None, - group=None, - qid=None, - label_lower_bound=None, label_upper_bound=None, - feature_names=None, feature_types=None, - feature_weights=None): + def data_handle(data, feature_names=None, feature_types=None, **kwargs): from .data import dispatch_device_quantile_dmatrix_set_data from .data import _device_quantile_transform data, feature_names, feature_types = _device_quantile_transform( @@ -313,16 +326,9 @@ class DataIter: ) dispatch_device_quantile_dmatrix_set_data(self.proxy, data) self.proxy.set_info( - label=label, - weight=weight, - base_margin=base_margin, - group=group, - qid=qid, - label_lower_bound=label_lower_bound, - label_upper_bound=label_upper_bound, feature_names=feature_names, feature_types=feature_types, - feature_weights=feature_weights + **kwargs, ) try: # Differ the exception in order to return 0 and stop the iteration. @@ -558,7 +564,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes feature_types=None, feature_weights=None ) -> None: - """Set meta info for DMatrix. See doc string for DMatrix constructor.""" + """Set meta info for DMatrix. See doc string for :py:obj:`xgboost.DMatrix`.""" from .data import dispatch_meta_backend if label is not None: @@ -959,76 +965,14 @@ class DMatrix: # pylint: disable=too-many-instance-attributes c_bst_ulong(0))) -class DeviceQuantileDMatrix(DMatrix): - """Device memory Data Matrix used in XGBoost for training with - tree_method='gpu_hist'. Do not use this for test/validation tasks as some - information may be lost in quantisation. This DMatrix is primarily designed - to save memory in training from device memory inputs by avoiding - intermediate storage. Set max_bin to control the number of bins during - quantisation. See doc string in `DMatrix` for documents on meta info. +class _ProxyDMatrix(DMatrix): + """A placeholder class when DMatrix cannot be constructed (DeviceQuantileDMatrix, + inplace_predict). - You can construct DeviceQuantileDMatrix from cupy/cudf/dlpack. - - .. versionadded:: 1.1.0 """ - @_deprecate_positional_args - def __init__( # pylint: disable=super-init-not-called - self, - data, - label=None, - *, - weight=None, - base_margin=None, - missing=None, - silent=False, - feature_names=None, - feature_types=None, - nthread: Optional[int] = None, - max_bin: int = 256, - group=None, - qid=None, - label_lower_bound=None, - label_upper_bound=None, - feature_weights=None, - enable_categorical: bool = False, - ): - self.max_bin = max_bin - self.missing = missing if missing is not None else np.nan - self.nthread = nthread if nthread is not None else 1 - self._silent = silent # unused, kept for compatibility - - if isinstance(data, ctypes.c_void_p): - self.handle = data - return - from .data import init_device_quantile_dmatrix - handle, feature_names, feature_types = init_device_quantile_dmatrix( - data, - label=label, weight=weight, - base_margin=base_margin, - group=group, - qid=qid, - missing=self.missing, - label_lower_bound=label_lower_bound, - label_upper_bound=label_upper_bound, - feature_weights=feature_weights, - feature_names=feature_names, - feature_types=feature_types, - threads=self.nthread, - max_bin=self.max_bin, - ) - if enable_categorical: - raise NotImplementedError( - 'categorical support is not enabled on DeviceQuantileDMatrix.' - ) - self.handle = handle - if qid is not None and group is not None: - raise ValueError( - 'Only one of the eval_qid or eval_group for each evaluation ' - 'dataset should be provided.' - ) - - self.feature_names = feature_names - self.feature_types = feature_types + def __init__(self): # pylint: disable=super-init-not-called + self.handle = ctypes.c_void_p() + _check_call(_LIB.XGProxyDMatrixCreate(ctypes.byref(self.handle))) def _set_data_from_cuda_interface(self, data): '''Set data from CUDA array interface.''' @@ -1053,6 +997,116 @@ class DeviceQuantileDMatrix(DMatrix): ) +class DeviceQuantileDMatrix(DMatrix): + """Device memory Data Matrix used in XGBoost for training with tree_method='gpu_hist'. Do + not use this for test/validation tasks as some information may be lost in + quantisation. This DMatrix is primarily designed to save memory in training from + device memory inputs by avoiding intermediate storage. Set max_bin to control the + number of bins during quantisation. See doc string in :py:obj:`xgboost.DMatrix` for + documents on meta info. + + You can construct DeviceQuantileDMatrix from cupy/cudf/dlpack. + + .. versionadded:: 1.1.0 + + """ + + @_deprecate_positional_args + def __init__( # pylint: disable=super-init-not-called + self, + data, + label=None, + *, + weight=None, + base_margin=None, + missing=None, + silent=False, + feature_names=None, + feature_types=None, + nthread: Optional[int] = None, + max_bin: int = 256, + group=None, + qid=None, + label_lower_bound=None, + label_upper_bound=None, + feature_weights=None, + enable_categorical: bool = False, + ): + self.max_bin = max_bin + self.missing = missing if missing is not None else np.nan + self.nthread = nthread if nthread is not None else 1 + self._silent = silent # unused, kept for compatibility + + if isinstance(data, ctypes.c_void_p): + self.handle = data + return + + if enable_categorical: + raise NotImplementedError( + 'categorical support is not enabled on DeviceQuantileDMatrix.' + ) + if qid is not None and group is not None: + raise ValueError( + 'Only one of the eval_qid or eval_group for each evaluation ' + 'dataset should be provided.' + ) + + self._init( + data, + label=label, + weight=weight, + base_margin=base_margin, + group=group, + qid=qid, + label_lower_bound=label_lower_bound, + label_upper_bound=label_upper_bound, + feature_weights=feature_weights, + feature_names=feature_names, + feature_types=feature_types, + ) + + def _init(self, data, feature_names, feature_types, **meta): + from .data import ( + _is_dlpack, + _transform_dlpack, + _is_iter, + SingleBatchInternalIter, + ) + + if _is_dlpack(data): + # We specialize for dlpack because cupy will take the memory from it so + # it can't be transformed twice. + data = _transform_dlpack(data) + if _is_iter(data): + it = data + else: + it = SingleBatchInternalIter( + data, **meta, feature_names=feature_names, feature_types=feature_types + ) + + reset_callback = ctypes.CFUNCTYPE(None, ctypes.c_void_p)(it.reset_wrapper) + next_callback = ctypes.CFUNCTYPE( + ctypes.c_int, + ctypes.c_void_p, + )(it.next_wrapper) + handle = ctypes.c_void_p() + ret = _LIB.XGDeviceQuantileDMatrixCreateFromCallback( + None, + it.proxy.handle, + reset_callback, + next_callback, + ctypes.c_float(self.missing), + ctypes.c_int(self.nthread), + ctypes.c_int(self.max_bin), + ctypes.byref(handle), + ) + if it.exception: + raise it.exception + # delay check_call to throw intermediate exception first + _check_call(ret) + self.handle = handle + + Objective = Callable[[np.ndarray, DMatrix], Tuple[np.ndarray, np.ndarray]] Metric = Callable[[np.ndarray, DMatrix], Tuple[str, float]] @@ -1346,7 +1400,7 @@ class Booster(object): def boost(self, dtrain, grad, hess): """Boost the booster for one iteration, with customized gradient - statistics. Like :func:`xgboost.core.Booster.update`, this + statistics. Like :py:func:`xgboost.Booster.update`, this function should not be called directly by users. Parameters @@ -1360,7 +1414,9 @@ class Booster(object): """ if len(grad) != len(hess): - raise ValueError('grad / hess length mismatch: {} / {}'.format(len(grad), len(hess))) + raise ValueError( + 'grad / hess length mismatch: {} / {}'.format(len(grad), len(hess)) + ) if not isinstance(dtrain, DMatrix): raise TypeError('invalid training matrix: {}'.format(type(dtrain).__name__)) self._validate_features(dtrain) @@ -1453,17 +1509,12 @@ class Booster(object): training=False): """Predict with data. - .. note:: This function is not thread safe except for ``gbtree`` - booster. + .. note:: This function is not thread safe except for ``gbtree`` booster. - For ``gbtree`` booster, the thread safety is guaranteed by locks. - For lock free prediction use ``inplace_predict`` instead. Also, the - safety does not hold when used in conjunction with other methods. - - When using booster other than ``gbtree``, predict can only be called - from one thread. If you want to run prediction using multiple - thread, call ``bst.copy()`` to make copies of model object and then - call ``predict()``. + When using booster other than ``gbtree``, predict can only be called from one + thread. If you want to run prediction using multiple thread, call + :py:meth:`xgboost.Booster.copy` to make copies of model object and then call + ``predict()``. Parameters ---------- @@ -1579,9 +1630,17 @@ class Booster(object): preds = preds.reshape(nrow, chunk_size) return preds - def inplace_predict(self, data, iteration_range=(0, 0), - predict_type='value', missing=np.nan): - '''Run prediction in-place, Unlike ``predict`` method, inplace prediction does + def inplace_predict( + self, + data, + iteration_range: Tuple[int, int] = (0, 0), + predict_type: str = "value", + missing: float = np.nan, + validate_features: bool = True, + base_margin: Any = None, + strict_shape: bool = False + ): + """Run prediction in-place, Unlike ``predict`` method, inplace prediction does not cache the prediction result. Calling only ``inplace_predict`` in multiple threads is safe and lock @@ -1617,6 +1676,15 @@ class Booster(object): missing : float Value in the input data which needs to be present as a missing value. + validate_features: + See :py:meth:`xgboost.Booster.predict` for details. + base_margin: + See :py:obj:`xgboost.DMatrix` for details. + strict_shape: + When set to True, output shape is invariant to whether classification is used. + For both value and margin prediction, the output shape is (n_samples, + n_groups), n_groups == 1 when multi-class is not used. Default to False, in + which case the output shape can be (n_samples, ) if multi-class is not used. Returns ------- @@ -1624,107 +1692,117 @@ class Booster(object): The prediction result. When input data is on GPU, prediction result is stored in a cupy array. - ''' - - def reshape_output(predt, rows): - '''Reshape for multi-output prediction.''' - if predt.size != rows and predt.size % rows == 0: - cols = int(predt.size / rows) - predt = predt.reshape(rows, cols) - return predt - return predt - - length = c_bst_ulong() + """ preds = ctypes.POINTER(ctypes.c_float)() - iteration_range = (ctypes.c_uint(iteration_range[0]), - ctypes.c_uint(iteration_range[1])) # once caching is supported, we can pass id(data) as cache id. try: import pandas as pd + if isinstance(data, pd.DataFrame): data = data.values except ImportError: pass + args = { + "type": 0, + "training": False, + "iteration_begin": iteration_range[0], + "iteration_end": iteration_range[1], + "missing": missing, + "strict_shape": strict_shape, + "cache_id": 0, + } + if predict_type == "margin": + args["type"] = 1 + shape = ctypes.POINTER(c_bst_ulong)() + dims = c_bst_ulong() + + if base_margin is not None: + proxy = _ProxyDMatrix() + proxy.set_info(base_margin=base_margin) + p_handle = proxy.handle + else: + proxy = None + p_handle = ctypes.c_void_p() + assert proxy is None or isinstance(proxy, _ProxyDMatrix) + if validate_features: + if len(data.shape) != 1 and self.num_features() != data.shape[1]: + raise ValueError( + f"Feature shape mismatch, expected: {self.num_features()}, " + f"got {data.shape[0]}" + ) + if isinstance(data, np.ndarray): - assert data.flags.c_contiguous - arr = np.array(data.reshape(data.size), copy=False, - dtype=np.float32) - _check_call(_LIB.XGBoosterPredictFromDense( - self.handle, - arr.ctypes.data_as(ctypes.POINTER(ctypes.c_float)), - c_bst_ulong(data.shape[0]), - c_bst_ulong(data.shape[1]), - ctypes.c_float(missing), - iteration_range[0], - iteration_range[1], - c_str(predict_type), - c_bst_ulong(0), - ctypes.byref(length), - ctypes.byref(preds) - )) - preds = ctypes2numpy(preds, length.value, np.float32) - rows = data.shape[0] - return reshape_output(preds, rows) + from .data import _maybe_np_slice + data = _maybe_np_slice(data, data.dtype) + _check_call( + _LIB.XGBoosterPredictFromDense( + self.handle, + _array_interface(data), + from_pystr_to_cstr(json.dumps(args)), + p_handle, + ctypes.byref(shape), + ctypes.byref(dims), + ctypes.byref(preds), + ) + ) + return _prediction_output(shape, dims, preds, False) if isinstance(data, scipy.sparse.csr_matrix): csr = data - _check_call(_LIB.XGBoosterPredictFromCSR( - self.handle, - c_array(ctypes.c_size_t, csr.indptr), - c_array(ctypes.c_uint, csr.indices), - c_array(ctypes.c_float, csr.data), - ctypes.c_size_t(len(csr.indptr)), - ctypes.c_size_t(len(csr.data)), - ctypes.c_size_t(csr.shape[1]), - ctypes.c_float(missing), - iteration_range[0], - iteration_range[1], - c_str(predict_type), - c_bst_ulong(0), - ctypes.byref(length), - ctypes.byref(preds))) - preds = ctypes2numpy(preds, length.value, np.float32) - rows = data.shape[0] - return reshape_output(preds, rows) - if lazy_isinstance(data, 'cupy.core.core', 'ndarray'): - assert data.flags.c_contiguous + _check_call( + _LIB.XGBoosterPredictFromCSR( + self.handle, + _array_interface(csr.indptr), + _array_interface(csr.indices), + _array_interface(csr.data), + ctypes.c_size_t(csr.shape[1]), + from_pystr_to_cstr(json.dumps(args)), + p_handle, + ctypes.byref(shape), + ctypes.byref(dims), + ctypes.byref(preds), + ) + ) + return _prediction_output(shape, dims, preds, False) + if lazy_isinstance(data, "cupy.core.core", "ndarray"): + from .data import _transform_cupy_array + data = _transform_cupy_array(data) interface = data.__cuda_array_interface__ - if 'mask' in interface: - interface['mask'] = interface['mask'].__cuda_array_interface__ - interface_str = bytes(json.dumps(interface, indent=2), 'utf-8') - _check_call(_LIB.XGBoosterPredictFromArrayInterface( - self.handle, - interface_str, - ctypes.c_float(missing), - iteration_range[0], - iteration_range[1], - c_str(predict_type), - c_bst_ulong(0), - ctypes.byref(length), - ctypes.byref(preds))) - mem = ctypes2cupy(preds, length, np.float32) - rows = data.shape[0] - return reshape_output(mem, rows) - if lazy_isinstance(data, 'cudf.core.dataframe', 'DataFrame'): + if "mask" in interface: + interface["mask"] = interface["mask"].__cuda_array_interface__ + interface_str = bytes(json.dumps(interface, indent=2), "utf-8") + _check_call( + _LIB.XGBoosterPredictFromArrayInterface( + self.handle, + interface_str, + from_pystr_to_cstr(json.dumps(args)), + p_handle, + ctypes.byref(shape), + ctypes.byref(dims), + ctypes.byref(preds), + ) + ) + return _prediction_output(shape, dims, preds, True) + if lazy_isinstance(data, "cudf.core.dataframe", "DataFrame"): from .data import _cudf_array_interfaces - interfaces_str = _cudf_array_interfaces(data) - _check_call(_LIB.XGBoosterPredictFromArrayInterfaceColumns( - self.handle, - interfaces_str, - ctypes.c_float(missing), - iteration_range[0], - iteration_range[1], - c_str(predict_type), - c_bst_ulong(0), - ctypes.byref(length), - ctypes.byref(preds))) - mem = ctypes2cupy(preds, length, np.float32) - rows = data.shape[0] - predt = reshape_output(mem, rows) - return predt - raise TypeError('Data type:' + str(type(data)) + - ' not supported by inplace prediction.') + interfaces_str = _cudf_array_interfaces(data) + _check_call( + _LIB.XGBoosterPredictFromArrayInterfaceColumns( + self.handle, + interfaces_str, + from_pystr_to_cstr(json.dumps(args)), + p_handle, + ctypes.byref(shape), + ctypes.byref(dims), + ctypes.byref(preds), + ) + ) + return _prediction_output(shape, dims, preds, True) + + raise TypeError( + "Data type:" + str(type(data)) + " not supported by inplace prediction." + ) def save_model(self, fname): """Save the model to a file. diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py index 1db498045..c75137a61 100644 --- a/python-package/xgboost/dask.py +++ b/python-package/xgboost/dask.py @@ -187,8 +187,8 @@ class DaskDMatrix: `DaskDMatrix` forces all lazy computation to be carried out. Wait for the input data explicitly if you want to see actual computation of constructing `DaskDMatrix`. - See doc string for DMatrix constructor for other parameters. DaskDMatrix accepts only - dask collection. + See doc for :py:obj:`xgboost.DMatrix` constructor for other parameters. DaskDMatrix + accepts only dask collection. .. note:: @@ -575,7 +575,8 @@ class DaskDeviceQuantileDMatrix(DaskDMatrix): memory usage by eliminating data copies. Internally the all partitions/chunks of data are merged by weighted GK sketching. So the number of partitions from dask may affect training accuracy as GK generates bounded error for each merge. See doc string for - `DeviceQuantileDMatrix` and `DMatrix` for other parameters. + :py:obj:`xgboost.DeviceQuantileDMatrix` and :py:obj:`xgboost.DMatrix` for other + parameters. .. versionadded:: 1.2.0 diff --git a/python-package/xgboost/data.py b/python-package/xgboost/data.py index 555d066f6..19041a879 100644 --- a/python-package/xgboost/data.py +++ b/python-package/xgboost/data.py @@ -5,11 +5,12 @@ import ctypes import json import warnings import os +from typing import Any import numpy as np from .core import c_array, _LIB, _check_call, c_str -from .core import DataIter, DeviceQuantileDMatrix, DMatrix +from .core import DataIter, _ProxyDMatrix, DMatrix from .compat import lazy_isinstance c_bst_ulong = ctypes.c_uint64 # pylint: disable=invalid-name @@ -113,7 +114,7 @@ def _maybe_np_slice(data, dtype): return data -def _transform_np_array(data: np.ndarray): +def _transform_np_array(data: np.ndarray) -> np.ndarray: if not isinstance(data, np.ndarray) and hasattr(data, '__array__'): data = np.array(data, copy=False) if len(data.shape) != 2: @@ -142,7 +143,7 @@ def _from_numpy_array(data, missing, nthread, feature_names, feature_types): input layout and type if memory use is a concern. """ - flatten = _transform_np_array(data) + flatten: np.ndarray = _transform_np_array(data) handle = ctypes.c_void_p() _check_call(_LIB.XGDMatrixCreateFromMat_omp( flatten.ctypes.data_as(ctypes.POINTER(ctypes.c_float)), @@ -783,54 +784,6 @@ class SingleBatchInternalIter(DataIter): # pylint: disable=R0902 self.it = 0 -def init_device_quantile_dmatrix( - data, missing, max_bin, threads, feature_names, feature_types, **meta -): - '''Constructor for DeviceQuantileDMatrix.''' - if not any([_is_cudf_df(data), _is_cudf_ser(data), _is_cupy_array(data), - _is_dlpack(data), _is_iter(data)]): - raise TypeError(str(type(data)) + - ' is not supported for DeviceQuantileDMatrix') - if _is_dlpack(data): - # We specialize for dlpack because cupy will take the memory from it so - # it can't be transformed twice. - data = _transform_dlpack(data) - if _is_iter(data): - it = data - else: - it = SingleBatchInternalIter( - data, **meta, feature_names=feature_names, - feature_types=feature_types) - - reset_factory = ctypes.CFUNCTYPE(None, ctypes.c_void_p) - reset_callback = reset_factory(it.reset_wrapper) - next_factory = ctypes.CFUNCTYPE( - ctypes.c_int, - ctypes.c_void_p, - ) - next_callback = next_factory(it.next_wrapper) - handle = ctypes.c_void_p() - ret = _LIB.XGDeviceQuantileDMatrixCreateFromCallback( - None, - it.proxy.handle, - reset_callback, - next_callback, - ctypes.c_float(missing), - ctypes.c_int(threads), - ctypes.c_int(max_bin), - ctypes.byref(handle) - ) - if it.exception: - raise it.exception - # delay check_call to throw intermediate exception first - _check_call(ret) - matrix = DeviceQuantileDMatrix(handle) - feature_names = matrix.feature_names - feature_types = matrix.feature_types - matrix.handle = None - return handle, feature_names, feature_types - - def _device_quantile_transform(data, feature_names, feature_types): if _is_cudf_df(data): return _transform_cudf_df(data, feature_names, feature_types) @@ -845,7 +798,7 @@ def _device_quantile_transform(data, feature_names, feature_types): str(type(data))) -def dispatch_device_quantile_dmatrix_set_data(proxy, data): +def dispatch_device_quantile_dmatrix_set_data(proxy: _ProxyDMatrix, data: Any) -> None: '''Dispatch for DeviceQuantileDMatrix.''' if _is_cudf_df(data): proxy._set_data_from_cuda_columnar(data) # pylint: disable=W0212 diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 8af8b2ab9..6a6728276 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -21,6 +21,7 @@ #include "xgboost/global_config.h" #include "c_api_error.h" +#include "c_api_utils.h" #include "../common/io.h" #include "../common/charconv.h" #include "../data/adapter.h" @@ -617,90 +618,92 @@ XGB_DLL int XGBoosterPredict(BoosterHandle handle, API_END(); } + +template +void InplacePredictImpl(std::shared_ptr x, std::shared_ptr p_m, + char const *c_json_config, Learner *learner, + size_t n_rows, size_t n_cols, + xgboost::bst_ulong const **out_shape, + xgboost::bst_ulong *out_dim, const float **out_result) { + auto config = Json::Load(StringView{c_json_config}); + CHECK_EQ(get(config["cache_id"]), 0) << "Cache ID is not supported yet"; + + HostDeviceVector* p_predt { nullptr }; + auto type = PredictionType(get(config["type"])); + learner->InplacePredict(x, p_m, type, get(config["missing"]), + &p_predt, + get(config["iteration_begin"]), + get(config["iteration_end"])); + CHECK(p_predt); + auto &shape = learner->GetThreadLocal().prediction_shape; + auto chunksize = n_rows == 0 ? 0 : p_predt->Size() / n_rows; + bool strict_shape = get(config["strict_shape"]); + CalcPredictShape(strict_shape, type, n_rows, n_cols, chunksize, learner->Groups(), + learner->BoostedRounds(), &shape, out_dim); + *out_result = dmlc::BeginPtr(p_predt->HostVector()); + *out_shape = dmlc::BeginPtr(shape); +} + // A hidden API as cache id is not being supported yet. -XGB_DLL int XGBoosterPredictFromDense(BoosterHandle handle, float *values, - xgboost::bst_ulong n_rows, - xgboost::bst_ulong n_cols, - float missing, - unsigned iteration_begin, - unsigned iteration_end, - char const* c_type, - xgboost::bst_ulong cache_id, - xgboost::bst_ulong *out_len, +XGB_DLL int XGBoosterPredictFromDense(BoosterHandle handle, + char const *array_interface, + char const *c_json_config, + DMatrixHandle m, + xgboost::bst_ulong const **out_shape, + xgboost::bst_ulong *out_dim, const float **out_result) { API_BEGIN(); CHECK_HANDLE(); - CHECK_EQ(cache_id, 0) << "Cache ID is not supported yet"; + std::shared_ptr x{ + new xgboost::data::ArrayAdapter(StringView{array_interface})}; + std::shared_ptr p_m {nullptr}; + if (m) { + p_m = *static_cast *>(m); + } auto *learner = static_cast(handle); - - std::shared_ptr x{ - new xgboost::data::DenseAdapter(values, n_rows, n_cols)}; - HostDeviceVector* p_predt { nullptr }; - std::string type { c_type }; - learner->InplacePredict(x, type, missing, &p_predt, iteration_begin, iteration_end); - CHECK(p_predt); - - *out_result = dmlc::BeginPtr(p_predt->HostVector()); - *out_len = static_cast(p_predt->Size()); + InplacePredictImpl(x, p_m, c_json_config, learner, x->NumRows(), + x->NumColumns(), out_shape, out_dim, out_result); API_END(); } // A hidden API as cache id is not being supported yet. -XGB_DLL int XGBoosterPredictFromCSR(BoosterHandle handle, - const size_t* indptr, - const unsigned* indices, - const bst_float* data, - size_t nindptr, - size_t nelem, - size_t num_col, - float missing, - unsigned iteration_begin, - unsigned iteration_end, - char const *c_type, - xgboost::bst_ulong cache_id, - xgboost::bst_ulong *out_len, +XGB_DLL int XGBoosterPredictFromCSR(BoosterHandle handle, char const *indptr, + char const *indices, char const *data, + xgboost::bst_ulong cols, + char const *c_json_config, DMatrixHandle m, + xgboost::bst_ulong const **out_shape, + xgboost::bst_ulong *out_dim, const float **out_result) { API_BEGIN(); CHECK_HANDLE(); - CHECK_EQ(cache_id, 0) << "Cache ID is not supported yet"; + std::shared_ptr x{ + new xgboost::data::CSRArrayAdapter{ + StringView{indptr}, StringView{indices}, StringView{data}, cols}}; + std::shared_ptr p_m {nullptr}; + if (m) { + p_m = *static_cast *>(m); + } auto *learner = static_cast(handle); - - std::shared_ptr x{ - new xgboost::data::CSRAdapter(indptr, indices, data, nindptr - 1, nelem, num_col)}; - HostDeviceVector* p_predt { nullptr }; - std::string type { c_type }; - learner->InplacePredict(x, type, missing, &p_predt, iteration_begin, iteration_end); - CHECK(p_predt); - - *out_result = dmlc::BeginPtr(p_predt->HostVector()); - *out_len = static_cast(p_predt->Size()); + InplacePredictImpl(x, p_m, c_json_config, learner, x->NumRows(), + x->NumColumns(), out_shape, out_dim, out_result); API_END(); } #if !defined(XGBOOST_USE_CUDA) -XGB_DLL int XGBoosterPredictFromArrayInterfaceColumns(BoosterHandle handle, - char const* c_json_strs, - float missing, - unsigned iteration_begin, - unsigned iteration_end, - char const* c_type, - xgboost::bst_ulong cache_id, - xgboost::bst_ulong *out_len, - float const** out_result) { +XGB_DLL int XGBoosterPredictFromArrayInterface( + BoosterHandle handle, char const *c_json_strs, char const *c_json_config, + DMatrixHandle m, xgboost::bst_ulong const **out_shape, xgboost::bst_ulong *out_dim, + const float **out_result) { API_BEGIN(); CHECK_HANDLE(); common::AssertGPUSupport(); API_END(); } -XGB_DLL int XGBoosterPredictFromArrayInterface(BoosterHandle handle, - char const* c_json_strs, - float missing, - unsigned iteration_begin, - unsigned iteration_end, - char const* c_type, - xgboost::bst_ulong cache_id, - xgboost::bst_ulong *out_len, - const float **out_result) { + +XGB_DLL int XGBoosterPredictFromArrayInterfaceColumns( + BoosterHandle handle, char const *c_json_strs, char const *c_json_config, + DMatrixHandle m, xgboost::bst_ulong const **out_shape, xgboost::bst_ulong *out_dim, + const float **out_result) { API_BEGIN(); CHECK_HANDLE(); common::AssertGPUSupport(); diff --git a/src/c_api/c_api.cu b/src/c_api/c_api.cu index ecebb0e98..baa4d8fdf 100644 --- a/src/c_api/c_api.cu +++ b/src/c_api/c_api.cu @@ -1,8 +1,9 @@ -// Copyright (c) 2019-2020 by Contributors +// Copyright (c) 2019-2021 by Contributors #include "xgboost/data.h" #include "xgboost/c_api.h" #include "xgboost/learner.h" #include "c_api_error.h" +#include "c_api_utils.h" #include "../data/device_adapter.cuh" using namespace xgboost; // NOLINT @@ -30,59 +31,63 @@ XGB_DLL int XGDMatrixCreateFromArrayInterface(char const* c_json_strs, API_END(); } -// A hidden API as cache id is not being supported yet. -XGB_DLL int XGBoosterPredictFromArrayInterfaceColumns(BoosterHandle handle, - char const* c_json_strs, - float missing, - unsigned iteration_begin, - unsigned iteration_end, - char const* c_type, - xgboost::bst_ulong cache_id, - xgboost::bst_ulong *out_len, - float const** out_result) { +template +int InplacePreidctCuda(BoosterHandle handle, char const *c_json_strs, + char const *c_json_config, + std::shared_ptr p_m, + xgboost::bst_ulong const **out_shape, + xgboost::bst_ulong *out_dim, const float **out_result) { API_BEGIN(); CHECK_HANDLE(); - CHECK_EQ(cache_id, 0) << "Cache ID is not supported yet"; - auto *learner = static_cast(handle); + auto config = Json::Load(StringView{c_json_config}); + CHECK_EQ(get(config["cache_id"]), 0) + << "Cache ID is not supported yet"; + auto *learner = static_cast(handle); std::string json_str{c_json_strs}; - auto x = std::make_shared(json_str); - HostDeviceVector* p_predt { nullptr }; - std::string type { c_type }; - learner->InplacePredict(x, type, missing, &p_predt, iteration_begin, iteration_end); + auto x = std::make_shared(json_str); + HostDeviceVector *p_predt{nullptr}; + auto type = PredictionType(get(config["type"])); + learner->InplacePredict(x, p_m, type, get(config["missing"]), + &p_predt, + get(config["iteration_begin"]), + get(config["iteration_end"])); CHECK(p_predt); - CHECK(p_predt->DeviceCanRead()); + CHECK(p_predt->DeviceCanRead() && !p_predt->HostCanRead()); + auto &shape = learner->GetThreadLocal().prediction_shape; + auto chunksize = x->NumRows() == 0 ? 0 : p_predt->Size() / x->NumRows(); + bool strict_shape = get(config["strict_shape"]); + CalcPredictShape(strict_shape, type, x->NumRows(), x->NumColumns(), chunksize, + learner->Groups(), learner->BoostedRounds(), &shape, + out_dim); + *out_shape = dmlc::BeginPtr(shape); *out_result = p_predt->ConstDevicePointer(); - *out_len = static_cast(p_predt->Size()); - API_END(); } + // A hidden API as cache id is not being supported yet. -XGB_DLL int XGBoosterPredictFromArrayInterface(BoosterHandle handle, - char const* c_json_strs, - float missing, - unsigned iteration_begin, - unsigned iteration_end, - char const* c_type, - xgboost::bst_ulong cache_id, - xgboost::bst_ulong *out_len, - float const** out_result) { - API_BEGIN(); - CHECK_HANDLE(); - CHECK_EQ(cache_id, 0) << "Cache ID is not supported yet"; - auto *learner = static_cast(handle); - - std::string json_str{c_json_strs}; - auto x = std::make_shared(json_str); - HostDeviceVector* p_predt { nullptr }; - std::string type { c_type }; - learner->InplacePredict(x, type, missing, &p_predt, iteration_begin, iteration_end); - CHECK(p_predt); - CHECK(p_predt->DeviceCanRead()); - - *out_result = p_predt->ConstDevicePointer(); - *out_len = static_cast(p_predt->Size()); - - API_END(); +XGB_DLL int XGBoosterPredictFromArrayInterfaceColumns( + BoosterHandle handle, char const *c_json_strs, char const *c_json_config, + DMatrixHandle m, xgboost::bst_ulong const **out_shape, + xgboost::bst_ulong *out_dim, const float **out_result) { + std::shared_ptr p_m {nullptr}; + if (m) { + p_m = *static_cast *>(m); + } + return InplacePreidctCuda( + handle, c_json_strs, c_json_config, p_m, out_shape, out_dim, out_result); +} + +// A hidden API as cache id is not being supported yet. +XGB_DLL int XGBoosterPredictFromArrayInterface( + BoosterHandle handle, char const *c_json_strs, char const *c_json_config, + DMatrixHandle m, xgboost::bst_ulong const **out_shape, + xgboost::bst_ulong *out_dim, const float **out_result) { + std::shared_ptr p_m {nullptr}; + if (m) { + p_m = *static_cast *>(m); + } + return InplacePreidctCuda( + handle, c_json_strs, c_json_config, p_m, out_shape, out_dim, out_result); } diff --git a/src/c_api/c_api_utils.h b/src/c_api/c_api_utils.h new file mode 100644 index 000000000..0ebc0b8be --- /dev/null +++ b/src/c_api/c_api_utils.h @@ -0,0 +1,114 @@ +/*! + * Copyright (c) 2021 by XGBoost Contributors + */ +#ifndef XGBOOST_C_API_C_API_UTILS_H_ +#define XGBOOST_C_API_C_API_UTILS_H_ + +#include +#include +#include + +#include "xgboost/logging.h" +#include "xgboost/learner.h" + +namespace xgboost { +/* \brief Determine the output shape of prediction. + * + * \param strict_shape Whether should we reshape the output with consideration of groups + * and forest. + * \param type Prediction type + * \param rows Input samples + * \param cols Input features + * \param chunksize Total elements of output / rows + * \param groups Number of output groups from Learner + * \param rounds end_iteration - beg_iteration + * \param out_shape Output shape + * \param out_dim Output dimension + */ +inline void CalcPredictShape(bool strict_shape, PredictionType type, size_t rows, size_t cols, + size_t chunksize, size_t groups, size_t rounds, + std::vector *out_shape, + xgboost::bst_ulong *out_dim) { + auto &shape = *out_shape; + if ((type == PredictionType::kMargin || type == PredictionType::kValue) && + rows != 0) { + CHECK_EQ(chunksize, groups); + } + + switch (type) { + case PredictionType::kValue: + case PredictionType::kMargin: { + if (chunksize == 1 && !strict_shape) { + *out_dim = 1; + shape.resize(*out_dim); + shape.front() = rows; + } else { + *out_dim = 2; + shape.resize(*out_dim); + shape.front() = rows; + shape.back() = groups; + } + break; + } + case PredictionType::kApproxContribution: + case PredictionType::kContribution: { + auto groups = chunksize / (cols + 1); + if (groups == 1 && !strict_shape) { + *out_dim = 2; + shape.resize(*out_dim); + shape.front() = rows; + shape.back() = cols + 1; + } else { + *out_dim = 3; + shape.resize(*out_dim); + shape[0] = rows; + shape[1] = groups; + shape[2] = cols + 1; + } + break; + } + case PredictionType::kInteraction: { + if (groups == 1 && !strict_shape) { + *out_dim = 3; + shape.resize(*out_dim); + shape[0] = rows; + shape[1] = cols + 1; + shape[2] = cols + 1; + } else { + *out_dim = 4; + shape.resize(*out_dim); + shape[0] = rows; + shape[1] = groups; + shape[2] = cols + 1; + shape[3] = cols + 1; + } + break; + } + case PredictionType::kLeaf: { + if (strict_shape) { + shape.resize(4); + shape[0] = rows; + shape[1] = rounds; + shape[2] = groups; + auto forest = chunksize / (shape[1] * shape[2]); + forest = std::max(static_cast(1), forest); + shape[3] = forest; + *out_dim = shape.size(); + } else { + *out_dim = 2; + shape.resize(*out_dim); + shape.front() = rows; + shape.back() = chunksize; + } + break; + } + default: { + LOG(FATAL) << "Unknown prediction type:" << static_cast(type); + } + } + CHECK_EQ( + std::accumulate(shape.cbegin(), shape.cend(), 1, std::multiplies<>{}), + chunksize * rows); +} +} // namespace xgboost +#endif // XGBOOST_C_API_C_API_UTILS_H_ diff --git a/src/data/adapter.h b/src/data/adapter.h index 4d7c924c3..856c1219a 100644 --- a/src/data/adapter.h +++ b/src/data/adapter.h @@ -1,5 +1,5 @@ /*! - * Copyright (c) 2019~2020 by Contributors + * Copyright (c) 2019~2021 by Contributors * \file adapter.h */ #ifndef XGBOOST_DATA_ADAPTER_H_ @@ -228,6 +228,128 @@ class DenseAdapter : public detail::SingleBatchDataIter { size_t num_columns_; }; +class ArrayAdapterBatch : public detail::NoMetaInfo { + ArrayInterface array_interface_; + + class Line { + ArrayInterface array_interface_; + size_t ridx_; + public: + Line(ArrayInterface array_interface, size_t ridx) + : array_interface_{std::move(array_interface)}, ridx_{ridx} {} + + size_t Size() const { return array_interface_.num_cols; } + + COOTuple GetElement(size_t idx) const { + return {ridx_, idx, array_interface_.GetElement(idx)}; + } + }; + + public: + ArrayAdapterBatch() = default; + Line const GetLine(size_t idx) const { + auto line = array_interface_.SliceRow(idx); + return Line{line, idx}; + } + + explicit ArrayAdapterBatch(ArrayInterface array_interface) + : array_interface_{std::move(array_interface)} {} +}; + +/** + * Adapter for dense array on host, in Python that's `numpy.ndarray`. This is similar to + * `DenseAdapter`, but supports __array_interface__ instead of raw pointers. An + * advantage is this can handle various data type without making a copy. + */ +class ArrayAdapter : public detail::SingleBatchDataIter { + public: + explicit ArrayAdapter(StringView array_interface) { + auto j = Json::Load(array_interface); + array_interface_ = ArrayInterface(get(j)); + batch_ = ArrayAdapterBatch{array_interface_}; + } + ArrayAdapterBatch const& Value() const override { return batch_; } + size_t NumRows() const { return array_interface_.num_rows; } + size_t NumColumns() const { return array_interface_.num_cols; } + + private: + ArrayAdapterBatch batch_; + ArrayInterface array_interface_; +}; + +class CSRArrayAdapterBatch : public detail::NoMetaInfo { + ArrayInterface indptr_; + ArrayInterface indices_; + ArrayInterface values_; + + class Line { + ArrayInterface indices_; + ArrayInterface values_; + size_t ridx_; + + public: + Line(ArrayInterface indices, ArrayInterface values, size_t ridx) + : indices_{std::move(indices)}, values_{std::move(values)}, ridx_{ridx} {} + + COOTuple GetElement(size_t idx) const { + return {ridx_, indices_.GetElement(idx), values_.GetElement(idx)}; + } + size_t Size() const { + return values_.num_rows * values_.num_cols; + } + }; + + public: + CSRArrayAdapterBatch() = default; + CSRArrayAdapterBatch(ArrayInterface indptr, ArrayInterface indices, + ArrayInterface values) + : indptr_{std::move(indptr)}, indices_{std::move(indices)}, + values_{std::move(values)} {} + + Line const GetLine(size_t idx) const { + auto begin_offset = indptr_.GetElement(idx); + auto end_offset = indptr_.GetElement(idx + 1); + auto indices = indices_.SliceOffset(begin_offset); + auto values = values_.SliceOffset(begin_offset); + values.num_cols = end_offset - begin_offset; + values.num_rows = 1; + indices.num_cols = values.num_cols; + indices.num_rows = values.num_rows; + return Line{indices, values, idx}; + } +}; + +/** + * Adapter for CSR array on host, in Python that's `scipy.sparse.csr_matrix`. This is + * similar to `CSRAdapter`, but supports __array_interface__ instead of raw pointers. An + * advantage is this can handle various data type without making a copy. + */ +class CSRArrayAdapter : public detail::SingleBatchDataIter { + public: + CSRArrayAdapter(StringView indptr, StringView indices, StringView values, + size_t num_cols) + : indptr_{indptr}, indices_{indices}, values_{values}, num_cols_{num_cols} { + batch_ = CSRArrayAdapterBatch{indptr_, indices_, values_}; + } + + CSRArrayAdapterBatch const& Value() const override { + return batch_; + } + size_t NumRows() const { + size_t size = indptr_.num_cols * indptr_.num_rows; + size = size == 0 ? 0 : size - 1; + return size; + } + size_t NumColumns() const { return num_cols_; } + + private: + CSRArrayAdapterBatch batch_; + ArrayInterface indptr_; + ArrayInterface indices_; + ArrayInterface values_; + size_t num_cols_; +}; + class CSCAdapterBatch : public detail::NoMetaInfo { public: CSCAdapterBatch(const size_t* col_ptr, const unsigned* row_idx, diff --git a/src/data/array_interface.h b/src/data/array_interface.h index 5539e16f0..d2dfdc66c 100644 --- a/src/data/array_interface.h +++ b/src/data/array_interface.h @@ -1,5 +1,5 @@ /*! - * Copyright 2019 by Contributors + * Copyright 2019-2021 by Contributors * \file array_interface.h * \brief View of __array_interface__ */ @@ -87,7 +87,7 @@ struct ArrayInterfaceErrors { } } - static std::string UnSupportedType(const char (&typestr)[3]) { + static std::string UnSupportedType(StringView typestr) { return TypeStr(typestr[1]) + " is not supported."; } }; @@ -210,6 +210,7 @@ class ArrayInterfaceHandler { static_cast(get(j_shape.at(1)))}; } } + template static common::Span ExtractData(std::map const& column) { Validate(column); @@ -257,16 +258,24 @@ class ArrayInterface { } auto typestr = get(column.at("typestr")); - type[0] = typestr.at(0); - type[1] = typestr.at(1); - type[2] = typestr.at(2); - this->CheckType(); + this->AssignType(StringView{typestr}); } + public: + enum Type : std::int8_t { kF4, kF8, kI1, kI2, kI4, kI8, kU1, kU2, kU4, kU8 }; + public: ArrayInterface() = default; - explicit ArrayInterface(std::string const& str, bool allow_mask = true) { - auto jinterface = Json::Load({str.c_str(), str.size()}); + explicit ArrayInterface(std::string const &str, bool allow_mask = true) + : ArrayInterface{StringView{str.c_str(), str.size()}, allow_mask} {} + + explicit ArrayInterface(std::map const &column, + bool allow_mask = true) { + this->Initialize(column, allow_mask); + } + + explicit ArrayInterface(StringView str, bool allow_mask = true) { + auto jinterface = Json::Load(str); if (IsA(jinterface)) { this->Initialize(get(jinterface), allow_mask); return; @@ -279,71 +288,114 @@ class ArrayInterface { } } - explicit ArrayInterface(std::map const &column, - bool allow_mask = true) { - this->Initialize(column, allow_mask); - } - - void CheckType() const { - if (type[1] == 'f' && type[2] == '4') { - return; - } else if (type[1] == 'f' && type[2] == '8') { - return; - } else if (type[1] == 'i' && type[2] == '1') { - return; - } else if (type[1] == 'i' && type[2] == '2') { - return; - } else if (type[1] == 'i' && type[2] == '4') { - return; - } else if (type[1] == 'i' && type[2] == '8') { - return; - } else if (type[1] == 'u' && type[2] == '1') { - return; - } else if (type[1] == 'u' && type[2] == '2') { - return; - } else if (type[1] == 'u' && type[2] == '4') { - return; - } else if (type[1] == 'u' && type[2] == '8') { - return; + void AssignType(StringView typestr) { + if (typestr[1] == 'f' && typestr[2] == '4') { + type = kF4; + } else if (typestr[1] == 'f' && typestr[2] == '8') { + type = kF8; + } else if (typestr[1] == 'i' && typestr[2] == '1') { + type = kI1; + } else if (typestr[1] == 'i' && typestr[2] == '2') { + type = kI2; + } else if (typestr[1] == 'i' && typestr[2] == '4') { + type = kI4; + } else if (typestr[1] == 'i' && typestr[2] == '8') { + type = kI8; + } else if (typestr[1] == 'u' && typestr[2] == '1') { + type = kU1; + } else if (typestr[1] == 'u' && typestr[2] == '2') { + type = kU2; + } else if (typestr[1] == 'u' && typestr[2] == '4') { + type = kU4; + } else if (typestr[1] == 'u' && typestr[2] == '8') { + type = kU8; } else { - LOG(FATAL) << ArrayInterfaceErrors::UnSupportedType(type); + LOG(FATAL) << ArrayInterfaceErrors::UnSupportedType(typestr); return; } } - XGBOOST_DEVICE float GetElement(size_t idx) const { + XGBOOST_DEVICE ArrayInterface SliceOffset(size_t offset) const { + void* p_values; + switch (type) { + case kF4: + p_values = reinterpret_cast(data) + offset; + break; + case kF8: + p_values = reinterpret_cast(data) + offset; + break; + case kI1: + p_values = reinterpret_cast(data) + offset; + break; + case kI2: + p_values = reinterpret_cast(data) + offset; + break; + case kI4: + p_values = reinterpret_cast(data) + offset; + break; + case kI8: + p_values = reinterpret_cast(data) + offset; + break; + case kU1: + p_values = reinterpret_cast(data) + offset; + break; + case kU2: + p_values = reinterpret_cast(data) + offset; + break; + case kU4: + p_values = reinterpret_cast(data) + offset; + break; + case kU8: + p_values = reinterpret_cast(data) + offset; + break; + } + ArrayInterface ret = *this; + ret.data = p_values; + return ret; + } + + XGBOOST_DEVICE ArrayInterface SliceRow(size_t idx) const { + size_t offset = idx * num_cols; + auto ret = this->SliceOffset(offset); + ret.num_rows = 1; + return ret; + } + + template + XGBOOST_DEVICE T GetElement(size_t idx) const { SPAN_CHECK(idx < num_cols * num_rows); - if (type[1] == 'f' && type[2] == '4') { + switch (type) { + case kF4: return reinterpret_cast(data)[idx]; - } else if (type[1] == 'f' && type[2] == '8') { + case kF8: return reinterpret_cast(data)[idx]; - } else if (type[1] == 'i' && type[2] == '1') { + case kI1: return reinterpret_cast(data)[idx]; - } else if (type[1] == 'i' && type[2] == '2') { + case kI2: return reinterpret_cast(data)[idx]; - } else if (type[1] == 'i' && type[2] == '4') { + case kI4: return reinterpret_cast(data)[idx]; - } else if (type[1] == 'i' && type[2] == '8') { + case kI8: return reinterpret_cast(data)[idx]; - } else if (type[1] == 'u' && type[2] == '1') { + case kU1: return reinterpret_cast(data)[idx]; - } else if (type[1] == 'u' && type[2] == '2') { + case kU2: return reinterpret_cast(data)[idx]; - } else if (type[1] == 'u' && type[2] == '4') { + case kU4: return reinterpret_cast(data)[idx]; - } else if (type[1] == 'u' && type[2] == '8') { + case kU8: return reinterpret_cast(data)[idx]; - } else { - SPAN_CHECK(false); - return 0; } + SPAN_CHECK(false); + return reinterpret_cast(data)[idx]; } RBitField8 valid; bst_row_t num_rows; bst_feature_t num_cols; void* data; - char type[3]; + + Type type; }; } // namespace xgboost diff --git a/src/data/data.cu b/src/data/data.cu index f5748e20b..fa1438340 100644 --- a/src/data/data.cu +++ b/src/data/data.cu @@ -1,5 +1,5 @@ /*! - * Copyright 2019 by XGBoost Contributors + * Copyright 2019-2021 by XGBoost Contributors * * \file data.cu * \brief Handles setting metainfo from array interface. @@ -45,15 +45,15 @@ auto SetDeviceToPtr(void *ptr) { } // anonymous namespace void CopyGroupInfoImpl(ArrayInterface column, std::vector* out) { - CHECK(column.type[1] == 'i' || column.type[1] == 'u') - << "Expected integer metainfo"; + CHECK(column.type != ArrayInterface::kF4 && column.type != ArrayInterface::kF8) + << "Expected integer for group info."; auto ptr_device = SetDeviceToPtr(column.data); dh::TemporaryArray temp(column.num_rows); auto d_tmp = temp.data(); dh::LaunchN(ptr_device, column.num_rows, [=] __device__(size_t idx) { - d_tmp[idx] = column.GetElement(idx); + d_tmp[idx] = column.GetElement(idx); }); auto length = column.num_rows; out->resize(length + 1); @@ -103,15 +103,15 @@ void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) { auto it = dh::MakeTransformIterator( thrust::make_counting_iterator(0ul), [array_interface] __device__(size_t i) { - return static_cast(array_interface.GetElement(i)); + return array_interface.GetElement(i); }); dh::caching_device_vector flag(1); auto d_flag = dh::ToSpan(flag); auto d = SetDeviceToPtr(array_interface.data); dh::LaunchN(d, 1, [=] __device__(size_t) { d_flag[0] = true; }); dh::LaunchN(d, array_interface.num_rows - 1, [=] __device__(size_t i) { - if (static_cast(array_interface.GetElement(i)) > - static_cast(array_interface.GetElement(i + 1))) { + if (array_interface.GetElement(i) > + array_interface.GetElement(i + 1)) { d_flag[0] = false; } }); diff --git a/src/gbm/gbtree.h b/src/gbm/gbtree.h index 25c9809ca..059804e58 100644 --- a/src/gbm/gbtree.h +++ b/src/gbm/gbtree.h @@ -1,5 +1,5 @@ /*! - * Copyright 2014-2020 by Contributors + * Copyright 2014-2021 by Contributors * \file gbtree.cc * \brief gradient boosted tree implementation. * \author Tianqi Chen @@ -265,15 +265,34 @@ class GBTree : public GradientBooster { bool training, unsigned ntree_limit) override; - void InplacePredict(dmlc::any const &x, float missing, - PredictionCacheEntry *out_preds, - uint32_t layer_begin, - unsigned layer_end) const override { + void InplacePredict(dmlc::any const &x, std::shared_ptr p_m, + float missing, PredictionCacheEntry *out_preds, + uint32_t layer_begin, unsigned layer_end) const override { CHECK(configured_); uint32_t tree_begin, tree_end; - std::tie(tree_begin, tree_end) = detail::LayerToTree(model_, tparam_, layer_begin, layer_end); - this->GetPredictor()->InplacePredict(x, model_, missing, out_preds, - tree_begin, tree_end); + std::tie(tree_begin, tree_end) = + detail::LayerToTree(model_, tparam_, layer_begin, layer_end); + std::vector predictors{ + cpu_predictor_.get(), +#if defined(XGBOOST_USE_CUDA) + gpu_predictor_.get() +#endif // defined(XGBOOST_USE_CUDA) + }; + StringView msg{"Unsupported data type for inplace predict."}; + if (tparam_.predictor == PredictorType::kAuto) { + // Try both predictor implementations + for (auto const &p : predictors) { + if (p && p->InplacePredict(x, p_m, model_, missing, out_preds, + tree_begin, tree_end)) { + return; + } + } + LOG(FATAL) << msg; + } else { + bool success = this->GetPredictor()->InplacePredict( + x, p_m, model_, missing, out_preds, tree_begin, tree_end); + CHECK(success) << msg; + } } void PredictInstance(const SparsePage::Inst& inst, diff --git a/src/learner.cc b/src/learner.cc index cb0618295..b4b7f89d9 100644 --- a/src/learner.cc +++ b/src/learner.cc @@ -1,5 +1,5 @@ /*! - * Copyright 2014-2020 by Contributors + * Copyright 2014-2021 by Contributors * \file learner.cc * \brief Implementation of learning algorithm. * \author Tianqi Chen @@ -1110,23 +1110,30 @@ class LearnerImpl : public LearnerIO { CHECK(!this->need_configuration_); return this->gbm_->BoostedRounds(); } + uint32_t Groups() const override { + CHECK(!this->need_configuration_); + return this->learner_model_param_.num_output_group; + } XGBAPIThreadLocalEntry& GetThreadLocal() const override { return (*LearnerAPIThreadLocalStore::Get())[this]; } - void InplacePredict(dmlc::any const &x, std::string const &type, - float missing, HostDeviceVector **out_preds, - uint32_t layer_begin, uint32_t layer_end) override { + void InplacePredict(dmlc::any const &x, std::shared_ptr p_m, + PredictionType type, float missing, + HostDeviceVector **out_preds, + uint32_t iteration_begin, + uint32_t iteration_end) override { this->Configure(); auto& out_predictions = this->GetThreadLocal().prediction_entry; - this->gbm_->InplacePredict(x, missing, &out_predictions, layer_begin, - layer_end); - if (type == "value") { + this->gbm_->InplacePredict(x, p_m, missing, &out_predictions, + iteration_begin, iteration_end); + if (type == PredictionType::kValue) { obj_->PredTransform(&out_predictions.predictions); - } else if (type == "margin") { + } else if (type == PredictionType::kMargin) { + // do nothing } else { - LOG(FATAL) << "Unsupported prediction type:" << type; + LOG(FATAL) << "Unsupported prediction type:" << static_cast(type); } *out_preds = &out_predictions.predictions; } diff --git a/src/predictor/cpu_predictor.cc b/src/predictor/cpu_predictor.cc index 54a4427e6..9fdf925db 100644 --- a/src/predictor/cpu_predictor.cc +++ b/src/predictor/cpu_predictor.cc @@ -1,5 +1,5 @@ /*! - * Copyright by Contributors 2017-2020 + * Copyright by Contributors 2017-2021 */ #include #include @@ -287,7 +287,7 @@ class CPUPredictor : public Predictor { } template - void DispatchedInplacePredict(dmlc::any const &x, + void DispatchedInplacePredict(dmlc::any const &x, std::shared_ptr p_m, const gbm::GBTreeModel &model, float missing, PredictionCacheEntry *out_preds, uint32_t tree_begin, uint32_t tree_end) const { @@ -295,33 +295,44 @@ class CPUPredictor : public Predictor { auto m = dmlc::get>(x); CHECK_EQ(m->NumColumns(), model.learner_model_param->num_feature) << "Number of columns in data must equal to trained model."; - MetaInfo info; - info.num_col_ = m->NumColumns(); - info.num_row_ = m->NumRows(); - this->InitOutPredictions(info, &(out_preds->predictions), model); - std::vector workspace(info.num_col_ * 8 * threads); + if (p_m) { + p_m->Info().num_row_ = m->NumRows(); + this->InitOutPredictions(p_m->Info(), &(out_preds->predictions), model); + } else { + MetaInfo info; + info.num_row_ = m->NumRows(); + this->InitOutPredictions(info, &(out_preds->predictions), model); + } + std::vector workspace(m->NumColumns() * 8 * threads); auto &predictions = out_preds->predictions.HostVector(); std::vector thread_temp; - InitThreadTemp(threads*kBlockOfRowsSize, model.learner_model_param->num_feature, - &thread_temp); - PredictBatchByBlockOfRowsKernel, - kBlockOfRowsSize>(AdapterView( - m.get(), missing, common::Span{workspace}), - &predictions, model, tree_begin, tree_end, &thread_temp); + InitThreadTemp(threads * kBlockOfRowsSize, + model.learner_model_param->num_feature, &thread_temp); + PredictBatchByBlockOfRowsKernel, kBlockOfRowsSize>( + AdapterView(m.get(), missing, common::Span{workspace}), + &predictions, model, tree_begin, tree_end, &thread_temp); } - void InplacePredict(dmlc::any const &x, const gbm::GBTreeModel &model, - float missing, PredictionCacheEntry *out_preds, - uint32_t tree_begin, unsigned tree_end) const override { + bool InplacePredict(dmlc::any const &x, std::shared_ptr p_m, + const gbm::GBTreeModel &model, float missing, + PredictionCacheEntry *out_preds, uint32_t tree_begin, + unsigned tree_end) const override { if (x.type() == typeid(std::shared_ptr)) { this->DispatchedInplacePredict( - x, model, missing, out_preds, tree_begin, tree_end); + x, p_m, model, missing, out_preds, tree_begin, tree_end); } else if (x.type() == typeid(std::shared_ptr)) { this->DispatchedInplacePredict( - x, model, missing, out_preds, tree_begin, tree_end); + x, p_m, model, missing, out_preds, tree_begin, tree_end); + } else if (x.type() == typeid(std::shared_ptr)) { + this->DispatchedInplacePredict ( + x, p_m, model, missing, out_preds, tree_begin, tree_end); + } else if (x.type() == typeid(std::shared_ptr)) { + this->DispatchedInplacePredict ( + x, p_m, model, missing, out_preds, tree_begin, tree_end); } else { - LOG(FATAL) << "Data type is not supported by CPU Predictor."; + return false; } + return true; } void PredictInstance(const SparsePage::Inst& inst, diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu index f06bf722b..d786229c9 100644 --- a/src/predictor/gpu_predictor.cu +++ b/src/predictor/gpu_predictor.cu @@ -1,5 +1,5 @@ /*! - * Copyright 2017-2020 by Contributors + * Copyright 2017-2021 by Contributors */ #include #include @@ -644,7 +644,7 @@ class GPUPredictor : public xgboost::Predictor { } template - void DispatchedInplacePredict(dmlc::any const &x, + void DispatchedInplacePredict(dmlc::any const &x, std::shared_ptr p_m, const gbm::GBTreeModel &model, float, PredictionCacheEntry *out_preds, uint32_t tree_begin, uint32_t tree_end) const { @@ -659,16 +659,20 @@ class GPUPredictor : public xgboost::Predictor { CHECK_EQ(this->generic_param_->gpu_id, m->DeviceIdx()) << "XGBoost is running on device: " << this->generic_param_->gpu_id << ", " << "but data is on: " << m->DeviceIdx(); - MetaInfo info; - info.num_col_ = m->NumColumns(); - info.num_row_ = m->NumRows(); - this->InitOutPredictions(info, &(out_preds->predictions), model); + if (p_m) { + p_m->Info().num_row_ = m->NumRows(); + this->InitOutPredictions(p_m->Info(), &(out_preds->predictions), model); + } else { + MetaInfo info; + info.num_row_ = m->NumRows(); + this->InitOutPredictions(info, &(out_preds->predictions), model); + } const uint32_t BLOCK_THREADS = 128; - auto GRID_SIZE = static_cast(common::DivRoundUp(info.num_row_, BLOCK_THREADS)); + auto GRID_SIZE = static_cast(common::DivRoundUp(m->NumRows(), BLOCK_THREADS)); size_t shared_memory_bytes = - SharedMemoryBytes(info.num_col_, max_shared_memory_bytes); + SharedMemoryBytes(m->NumColumns(), max_shared_memory_bytes); bool use_shared = shared_memory_bytes != 0; size_t entry_start = 0; @@ -680,23 +684,25 @@ class GPUPredictor : public xgboost::Predictor { d_model.categories_tree_segments.ConstDeviceSpan(), d_model.categories_node_segments.ConstDeviceSpan(), d_model.categories.ConstDeviceSpan(), tree_begin, tree_end, m->NumColumns(), - info.num_row_, entry_start, use_shared, output_groups); + m->NumRows(), entry_start, use_shared, output_groups); } - void InplacePredict(dmlc::any const &x, const gbm::GBTreeModel &model, - float missing, PredictionCacheEntry *out_preds, - uint32_t tree_begin, unsigned tree_end) const override { + bool InplacePredict(dmlc::any const &x, std::shared_ptr p_m, + const gbm::GBTreeModel &model, float missing, + PredictionCacheEntry *out_preds, uint32_t tree_begin, + unsigned tree_end) const override { if (x.type() == typeid(std::shared_ptr)) { this->DispatchedInplacePredict< data::CupyAdapter, DeviceAdapterLoader>( - x, model, missing, out_preds, tree_begin, tree_end); + x, p_m, model, missing, out_preds, tree_begin, tree_end); } else if (x.type() == typeid(std::shared_ptr)) { this->DispatchedInplacePredict< data::CudfAdapter, DeviceAdapterLoader>( - x, model, missing, out_preds, tree_begin, tree_end); + x, p_m, model, missing, out_preds, tree_begin, tree_end); } else { - LOG(FATAL) << "Only CuPy and CuDF are supported by GPU Predictor."; + return false; } + return true; } void PredictContribution(DMatrix* p_fmat, diff --git a/tests/cpp/predictor/test_predictor.cc b/tests/cpp/predictor/test_predictor.cc index 9206ba2aa..7be88fea0 100644 --- a/tests/cpp/predictor/test_predictor.cc +++ b/tests/cpp/predictor/test_predictor.cc @@ -1,5 +1,5 @@ /*! - * Copyright 2020 by Contributors + * Copyright 2020-2021 by Contributors */ #include @@ -104,21 +104,24 @@ void TestInplacePrediction(dmlc::any x, std::string predictor, } HostDeviceVector *p_out_predictions_0{nullptr}; - learner->InplacePredict(x, "margin", std::numeric_limits::quiet_NaN(), + learner->InplacePredict(x, nullptr, PredictionType::kMargin, + std::numeric_limits::quiet_NaN(), &p_out_predictions_0, 0, 2); CHECK(p_out_predictions_0); HostDeviceVector predict_0 (p_out_predictions_0->Size()); predict_0.Copy(*p_out_predictions_0); HostDeviceVector *p_out_predictions_1{nullptr}; - learner->InplacePredict(x, "margin", std::numeric_limits::quiet_NaN(), + learner->InplacePredict(x, nullptr, PredictionType::kMargin, + std::numeric_limits::quiet_NaN(), &p_out_predictions_1, 2, 4); CHECK(p_out_predictions_1); HostDeviceVector predict_1 (p_out_predictions_1->Size()); predict_1.Copy(*p_out_predictions_1); HostDeviceVector* p_out_predictions{nullptr}; - learner->InplacePredict(x, "margin", std::numeric_limits::quiet_NaN(), + learner->InplacePredict(x, nullptr, PredictionType::kMargin, + std::numeric_limits::quiet_NaN(), &p_out_predictions, 0, 4); auto& h_pred = p_out_predictions->HostVector(); diff --git a/tests/python-gpu/test_device_quantile_dmatrix.py b/tests/python-gpu/test_device_quantile_dmatrix.py index 2695a1168..348d75842 100644 --- a/tests/python-gpu/test_device_quantile_dmatrix.py +++ b/tests/python-gpu/test_device_quantile_dmatrix.py @@ -11,8 +11,7 @@ import testing as tm class TestDeviceQuantileDMatrix: def test_dmatrix_numpy_init(self): data = np.random.randn(5, 5) - with pytest.raises(TypeError, - match='is not supported for DeviceQuantileDMatrix'): + with pytest.raises(TypeError, match='is not supported'): xgb.DeviceQuantileDMatrix(data, np.ones(5, dtype=np.float64)) @pytest.mark.skipif(**tm.no_cupy()) diff --git a/tests/python-gpu/test_gpu_prediction.py b/tests/python-gpu/test_gpu_prediction.py index 7b16cb15d..2688d3ac2 100644 --- a/tests/python-gpu/test_gpu_prediction.py +++ b/tests/python-gpu/test_gpu_prediction.py @@ -141,6 +141,13 @@ class TestGPUPredict: assert np.allclose(cpu_train_score, gpu_train_score) assert np.allclose(cpu_test_score, gpu_test_score) + def run_inplace_base_margin(self, booster, dtrain, X, base_margin): + import cupy as cp + dtrain.set_info(base_margin=base_margin) + from_inplace = booster.inplace_predict(data=X, base_margin=base_margin) + from_dmatrix = booster.predict(dtrain) + cp.testing.assert_allclose(from_inplace, from_dmatrix) + @pytest.mark.skipif(**tm.no_cupy()) def test_inplace_predict_cupy(self): import cupy as cp @@ -175,6 +182,9 @@ class TestGPUPredict: for i in range(10): run_threaded_predict(X, rows, predict_dense) + base_margin = cp_rng.randn(rows) + self.run_inplace_base_margin(booster, dtrain, X, base_margin) + @pytest.mark.skipif(**tm.no_cudf()) def test_inplace_predict_cudf(self): import cupy as cp @@ -208,6 +218,9 @@ class TestGPUPredict: for i in range(10): run_threaded_predict(X, rows, predict_df) + base_margin = cudf.Series(rng.randn(rows)) + self.run_inplace_base_margin(booster, dtrain, X, base_margin) + @given(strategies.integers(1, 10), tm.dataset_strategy, shap_parameter_strategy) @settings(deadline=None) diff --git a/tests/python/test_predict.py b/tests/python/test_predict.py index ef719bd47..2dfd40005 100644 --- a/tests/python/test_predict.py +++ b/tests/python/test_predict.py @@ -80,20 +80,28 @@ def test_predict_leaf(): class TestInplacePredict: '''Tests for running inplace prediction''' + @classmethod + def setup_class(cls): + cls.rows = 100 + cls.cols = 10 + + cls.rng = np.random.RandomState(1994) + + cls.X = cls.rng.randn(cls.rows, cls.cols) + cls.y = cls.rng.randn(cls.rows) + + dtrain = xgb.DMatrix(cls.X, cls.y) + + cls.booster = xgb.train({'tree_method': 'hist'}, + dtrain, num_boost_round=10) + + cls.test = xgb.DMatrix(cls.X[:10, ...]) + def test_predict(self): - rows = 1000 - cols = 10 + booster = self.booster + X = self.X + test = self.test - np.random.seed(1994) - - X = np.random.randn(rows, cols) - y = np.random.randn(rows) - dtrain = xgb.DMatrix(X, y) - - booster = xgb.train({'tree_method': 'hist'}, - dtrain, num_boost_round=10) - - test = xgb.DMatrix(X[:10, ...]) predt_from_array = booster.inplace_predict(X[:10, ...]) predt_from_dmatrix = booster.predict(test) @@ -111,7 +119,7 @@ class TestInplacePredict: return np.all(copied_predt == inplace_predt) for i in range(10): - run_threaded_predict(X, rows, predict_dense) + run_threaded_predict(X, self.rows, predict_dense) def predict_csr(x): inplace_predt = booster.inplace_predict(sparse.csr_matrix(x)) @@ -120,4 +128,14 @@ class TestInplacePredict: return np.all(copied_predt == inplace_predt) for i in range(10): - run_threaded_predict(X, rows, predict_csr) + run_threaded_predict(X, self.rows, predict_csr) + + def test_base_margin(self): + booster = self.booster + + base_margin = self.rng.randn(self.rows) + from_inplace = booster.inplace_predict(data=self.X, base_margin=base_margin) + + dtrain = xgb.DMatrix(self.X, self.y, base_margin=base_margin) + from_dmatrix = booster.predict(dtrain) + np.testing.assert_allclose(from_dmatrix, from_inplace) diff --git a/tests/python/test_with_shap.py b/tests/python/test_with_shap.py index 6d0d03720..253ce25e9 100644 --- a/tests/python/test_with_shap.py +++ b/tests/python/test_with_shap.py @@ -1,6 +1,5 @@ import numpy as np import xgboost as xgb -import testing as tm import pytest try: