From 4656b09d5d498c7ad9f14d2b061fcead0ada7ed0 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Mon, 8 Feb 2021 18:26:32 +0800 Subject: [PATCH] [breaking] Add prediction fucntion for DMatrix and use inplace predict for dask. (#6668) * Add a new API function for predicting on `DMatrix`. This function aligns with rest of the `XGBoosterPredictFrom*` functions on semantic of function arguments. * Purge `ntree_limit` from libxgboost, use iteration instead. * [dask] Use `inplace_predict` by default for dask sklearn models. * [dask] Run prediction shape inference on worker instead of client. The breaking change is in the Python sklearn `apply` function, I made it to be consistent with other prediction functions where `best_iteration` is used by default. --- include/xgboost/c_api.h | 162 +++++++++- include/xgboost/gbm.h | 40 +-- include/xgboost/learner.h | 9 +- include/xgboost/predictor.h | 27 +- python-package/xgboost/core.py | 221 ++++++++----- python-package/xgboost/dask.py | 378 +++++++++++++--------- python-package/xgboost/sklearn.py | 136 +++++--- python-package/xgboost/training.py | 22 +- src/c_api/c_api.cc | 60 +++- src/c_api/c_api.cu | 6 +- src/c_api/c_api_utils.h | 35 +- src/cli_main.cc | 17 +- src/gbm/gblinear.cc | 27 +- src/gbm/gbtree.cc | 71 ++-- src/gbm/gbtree.h | 48 ++- src/learner.cc | 24 +- src/predictor/cpu_predictor.cc | 54 +--- src/predictor/gpu_predictor.cu | 103 ++---- tests/ci_build/conda_env/cpu_test.yml | 1 + tests/cpp/gbm/test_gbtree.cc | 47 +++ tests/cpp/predictor/test_cpu_predictor.cc | 4 +- tests/cpp/predictor/test_gpu_predictor.cu | 1 - tests/cpp/predictor/test_predictor.cc | 12 +- tests/cpp/test_learner.cc | 7 +- tests/python-gpu/test_from_cupy.py | 15 +- tests/python-gpu/test_gpu_with_dask.py | 68 ++-- tests/python/test_basic_models.py | 6 + tests/python/test_predict.py | 41 ++- tests/python/test_with_dask.py | 96 ++++-- 29 files changed, 1134 insertions(+), 604 deletions(-) diff --git a/include/xgboost/c_api.h b/include/xgboost/c_api.h index 38fa7fe80..60c77bb38 100644 --- a/include/xgboost/c_api.h +++ b/include/xgboost/c_api.h @@ -704,8 +704,9 @@ XGB_DLL int XGBoosterEvalOneIter(BoosterHandle handle, const char *evnames[], bst_ulong len, const char **out_result); + /*! - * \brief make prediction based on dmat + * \brief make prediction based on dmat (deprecated, use `XGBoosterPredictFromDMatrix` instead) * \param handle handle * \param dmat data matrix * \param option_mask bit-mask of options taken in prediction, possible values @@ -734,6 +735,165 @@ XGB_DLL int XGBoosterPredict(BoosterHandle handle, int training, bst_ulong *out_len, const float **out_result); +/*! + * \brief Make prediction from DMatrix, replacing `XGBoosterPredict`. + * + * \param handle Booster handle + * \param dmat DMatrix handle + * \param c_json_config String encoded predict configuration in JSON format. + * + * "type": [0, 5] + * 0: normal prediction + * 1: output margin + * 2: predict contribution + * 3: predict approxmated contribution + * 4: predict feature interaction + * 5: predict leaf + * "training": bool + * Whether the prediction function is used as part of a training loop. **Not used + * for inplace prediction**. + * + * Prediction can be run in 2 scenarios: + * 1. Given data matrix X, obtain prediction y_pred from the model. + * 2. Obtain the prediction for computing gradients. For example, DART booster performs dropout + * during training, and the prediction result will be different from the one obtained by normal + * inference step due to dropped trees. + * Set training=false for the first scenario. Set training=true for the second + * scenario. The second scenario applies when you are defining a custom objective + * function. + * "iteration_begin": int + * Beginning iteration of prediction. + * "iteration_end": int + * End iteration of prediction. Set to 0 this will become the size of tree model. + * "strict_shape": bool + * Whether should we reshape the output with stricter rules. If set to true, + * normal/margin/contrib/interaction predict will output consistent shape + * disregarding the use of multi-class model, and leaf prediction will output 4-dim + * array representing: (n_samples, n_iterations, n_classes, n_trees_in_forest) + * + * Run a normal prediction with strict output shape, 2 dim for softprob , 1 dim for others. + * \code + * { + * "type": 0, + * "training": False, + * "iteration_begin": 0, + * "iteration_end": 0, + * "strict_shape": true, + * } + * \endcode + * + * \param out_shape Shape of output prediction (copy before use). + * \param out_dim Dimension of output prediction. + * \param out_result Buffer storing prediction value (copy before use). + * + * \return 0 when success, -1 when failure happens + */ +XGB_DLL int XGBoosterPredictFromDMatrix(BoosterHandle handle, + DMatrixHandle dmat, + char const* c_json_config, + bst_ulong const **out_shape, + bst_ulong *out_dim, + float const **out_result); +/* + * \brief Inplace prediction from CPU dense matrix. + * + * \param handle Booster handle. + * \param values JSON encoded __array_interface__ to values. + * \param c_json_config See `XGBoosterPredictFromDMatrix` for more info. + * + * Additional fields for inplace prediction are: + * "missing": float + * + * \param m An optional (NULL if not available) proxy DMatrix instance + * storing meta info. + * + * \param out_shape See `XGBoosterPredictFromDMatrix` for more info. + * \param out_dim See `XGBoosterPredictFromDMatrix` for more info. + * \param out_result See `XGBoosterPredictFromDMatrix` for more info. + * + * \return 0 when success, -1 when failure happens + */ +XGB_DLL int XGBoosterPredictFromDense(BoosterHandle handle, + char const *values, + char const *c_json_config, + DMatrixHandle m, + bst_ulong const **out_shape, + bst_ulong *out_dim, + const float **out_result); + +/* + * \brief Inplace prediction from CPU CSR matrix. + * + * \param handle Booster handle. + * \param indptr JSON encoded __array_interface__ to row pointer in CSR. + * \param indices JSON encoded __array_interface__ to column indices in CSR. + * \param values JSON encoded __array_interface__ to values in CSR.. + * \param ncol Number of features in data. + * \param c_json_config See `XGBoosterPredictFromDMatrix` for more info. + * Additional fields for inplace prediction are: + * "missing": float + * + * \param m An optional (NULL if not available) proxy DMatrix instance + * storing meta info. + * + * \param out_shape See `XGBoosterPredictFromDMatrix` for more info. + * \param out_dim See `XGBoosterPredictFromDMatrix` for more info. + * \param out_result See `XGBoosterPredictFromDMatrix` for more info. + * + * \return 0 when success, -1 when failure happens + */ +XGB_DLL int XGBoosterPredictFromCSR(BoosterHandle handle, char const *indptr, + char const *indices, char const *values, + bst_ulong ncol, + char const *c_json_config, DMatrixHandle m, + bst_ulong const **out_shape, + bst_ulong *out_dim, + const float **out_result); + +/* + * \brief Inplace prediction from CUDA Dense matrix (cupy in Python). + * + * \param handle Booster handle + * \param values JSON encoded __cuda_array_interface__ to values. + * \param c_json_config See `XGBoosterPredictFromDMatrix` for more info. + * Additional fields for inplace prediction are: + * "missing": float + * + * \param m An optional (NULL if not available) proxy DMatrix instance + * storing meta info. + * \param out_shape See `XGBoosterPredictFromDMatrix` for more info. + * \param out_dim See `XGBoosterPredictFromDMatrix` for more info. + * \param out_result See `XGBoosterPredictFromDMatrix` for more info. + * + * \return 0 when success, -1 when failure happens + */ +XGB_DLL int XGBoosterPredictFromCudaArray( + BoosterHandle handle, char const *values, char const *c_json_config, + DMatrixHandle m, bst_ulong const **out_shape, bst_ulong *out_dim, + const float **out_result); + +/* + * \brief Inplace prediction from CUDA dense dataframe (cuDF in Python). + * + * \param handle Booster handle + * \param values List of __cuda_array_interface__ for all columns encoded in JSON list. + * \param c_json_config See `XGBoosterPredictFromDMatrix` for more info. + * Additional fields for inplace prediction are: + * "missing": float + * + * \param m An optional (NULL if not available) proxy DMatrix instance + * storing meta info. + * \param out_shape See `XGBoosterPredictFromDMatrix` for more info. + * \param out_dim See `XGBoosterPredictFromDMatrix` for more info. + * \param out_result See `XGBoosterPredictFromDMatrix` for more info. + * + * \return 0 when success, -1 when failure happens + */ +XGB_DLL int XGBoosterPredictFromCudaColumnar( + BoosterHandle handle, char const *values, char const *c_json_config, + DMatrixHandle m, bst_ulong const **out_shape, bst_ulong *out_dim, + const float **out_result); + /* * ========================== Begin Serialization APIs ========================= diff --git a/include/xgboost/gbm.h b/include/xgboost/gbm.h index a49377521..c49fe4747 100644 --- a/include/xgboost/gbm.h +++ b/include/xgboost/gbm.h @@ -63,7 +63,7 @@ class GradientBooster : public Model, public Configurable { /*! * \brief Slice a model using boosting index. The slice m:n indicates taking all trees * that were fit during the boosting rounds m, (m+1), (m+2), ..., (n-1). - * \param layer_begin Begining of boosted tree layer used for prediction. + * \param layer_begin Beginning of boosted tree layer used for prediction. * \param layer_end End of booster layer. 0 means do not limit trees. * \param out Output gradient booster */ @@ -99,15 +99,14 @@ class GradientBooster : public Model, public Configurable { * \param out_preds output vector to hold the predictions * \param training Whether the prediction value is used for training. For dart booster * drop out is performed during training. - * \param ntree_limit limit the number of trees used in prediction, - * when it equals 0, this means we do not limit - * number of trees, this parameter is only valid - * for gbtree, but not for gblinear + * \param layer_begin Beginning of boosted tree layer used for prediction. + * \param layer_end End of booster layer. 0 means do not limit trees. */ virtual void PredictBatch(DMatrix* dmat, PredictionCacheEntry* out_preds, bool training, - unsigned ntree_limit = 0) = 0; + unsigned layer_begin, + unsigned layer_end) = 0; /*! * \brief Inplace prediction. @@ -115,7 +114,7 @@ class GradientBooster : public Model, public Configurable { * \param x A type erased data adapter. * \param missing Missing value in the data. * \param [in,out] out_preds The output preds. - * \param layer_begin (Optional) Begining of boosted tree layer used for prediction. + * \param layer_begin (Optional) Beginning of boosted tree layer used for prediction. * \param layer_end (Optional) End of booster layer. 0 means do not limit trees. */ virtual void InplacePredict(dmlc::any const &, std::shared_ptr, float, @@ -132,44 +131,45 @@ class GradientBooster : public Model, public Configurable { * * \param inst the instance you want to predict * \param out_preds output vector to hold the predictions - * \param ntree_limit limit the number of trees used in prediction + * \param layer_begin Beginning of boosted tree layer used for prediction. + * \param layer_end End of booster layer. 0 means do not limit trees. * \sa Predict */ virtual void PredictInstance(const SparsePage::Inst& inst, std::vector* out_preds, - unsigned ntree_limit = 0) = 0; + unsigned layer_begin, unsigned layer_end) = 0; /*! * \brief predict the leaf index of each tree, the output will be nsample * ntree vector * this is only valid in gbtree predictor * \param dmat feature matrix * \param out_preds output vector to hold the predictions - * \param ntree_limit limit the number of trees used in prediction, when it equals 0, this means - * we do not limit number of trees, this parameter is only valid for gbtree, but not for gblinear + * \param layer_begin Beginning of boosted tree layer used for prediction. + * \param layer_end End of booster layer. 0 means do not limit trees. */ - virtual void PredictLeaf(DMatrix* dmat, - HostDeviceVector* out_preds, - unsigned ntree_limit = 0) = 0; + virtual void PredictLeaf(DMatrix *dmat, + HostDeviceVector *out_preds, + unsigned layer_begin, unsigned layer_end) = 0; /*! * \brief feature contributions to individual predictions; the output will be a vector * of length (nfeats + 1) * num_output_group * nsample, arranged in that order * \param dmat feature matrix * \param out_contribs output vector to hold the contributions - * \param ntree_limit limit the number of trees used in prediction, when it equals 0, this means - * we do not limit number of trees + * \param layer_begin Beginning of boosted tree layer used for prediction. + * \param layer_end End of booster layer. 0 means do not limit trees. * \param approximate use a faster (inconsistent) approximation of SHAP values * \param condition condition on the condition_feature (0=no, -1=cond off, 1=cond on). * \param condition_feature feature to condition on (i.e. fix) during calculations */ virtual void PredictContribution(DMatrix* dmat, HostDeviceVector* out_contribs, - unsigned ntree_limit = 0, + unsigned layer_begin, unsigned layer_end, bool approximate = false, int condition = 0, unsigned condition_feature = 0) = 0; - virtual void PredictInteractionContributions(DMatrix* dmat, - HostDeviceVector* out_contribs, - unsigned ntree_limit, bool approximate) = 0; + virtual void PredictInteractionContributions( + DMatrix *dmat, HostDeviceVector *out_contribs, + unsigned layer_begin, unsigned layer_end, bool approximate) = 0; /*! * \brief dump the model in the requested format diff --git a/include/xgboost/learner.h b/include/xgboost/learner.h index d2bd51080..8676e5a25 100644 --- a/include/xgboost/learner.h +++ b/include/xgboost/learner.h @@ -113,8 +113,8 @@ class Learner : public Model, public Configurable, public dmlc::Serializable { * \param data input data * \param output_margin whether to only predict margin value instead of transformed prediction * \param out_preds output vector that stores the prediction - * \param ntree_limit limit number of trees used for boosted tree - * predictor, when it equals 0, this means we are using all the trees + * \param layer_begin Beginning of boosted tree layer used for prediction. + * \param layer_end End of booster layer. 0 means do not limit trees. * \param training Whether the prediction result is used for training * \param pred_leaf whether to only predict the leaf index of each tree in a boosted tree predictor * \param pred_contribs whether to only predict the feature contributions @@ -124,7 +124,8 @@ class Learner : public Model, public Configurable, public dmlc::Serializable { virtual void Predict(std::shared_ptr data, bool output_margin, HostDeviceVector *out_preds, - unsigned ntree_limit = 0, + unsigned layer_begin, + unsigned layer_end, bool training = false, bool pred_leaf = false, bool pred_contribs = false, @@ -140,7 +141,7 @@ class Learner : public Model, public Configurable, public dmlc::Serializable { * \param type Prediction type. * \param missing Missing value in the data. * \param [in,out] out_preds Pointer to output prediction vector. - * \param layer_begin Begining of boosted tree layer used for prediction. + * \param layer_begin Beginning of boosted tree layer used for prediction. * \param layer_end End of booster layer. 0 means do not limit trees. */ virtual void InplacePredict(dmlc::any const &x, diff --git a/include/xgboost/predictor.h b/include/xgboost/predictor.h index 442cf91dc..4664ada3e 100644 --- a/include/xgboost/predictor.h +++ b/include/xgboost/predictor.h @@ -127,12 +127,11 @@ class Predictor { * \param [in,out] out_preds The output preds. * \param model The model to predict from. * \param tree_begin The tree begin index. - * \param ntree_limit (Optional) The ntree limit. 0 means do not - * limit trees. + * \param tree_end The tree end index. */ virtual void PredictBatch(DMatrix* dmat, PredictionCacheEntry* out_preds, - const gbm::GBTreeModel& model, int tree_begin, - uint32_t const ntree_limit = 0) const = 0; + const gbm::GBTreeModel& model, uint32_t tree_begin, + uint32_t tree_end = 0) const = 0; /** * \brief Inplace prediction. @@ -140,7 +139,7 @@ class Predictor { * \param model The model to predict from. * \param missing Missing value in the data. * \param [in,out] out_preds The output preds. - * \param tree_begin (Optional) Begining of boosted trees used for prediction. + * \param tree_begin (Optional) Beginning of boosted trees used for prediction. * \param tree_end (Optional) End of booster trees. 0 means do not limit trees. * * \return True if the data can be handled by current predictor, false otherwise. @@ -159,13 +158,13 @@ class Predictor { * \param inst The instance to predict. * \param [in,out] out_preds The output preds. * \param model The model to predict from - * \param ntree_limit (Optional) The ntree limit. + * \param tree_end (Optional) The tree end index. */ virtual void PredictInstance(const SparsePage::Inst& inst, std::vector* out_preds, const gbm::GBTreeModel& model, - unsigned ntree_limit = 0) const = 0; + unsigned tree_end = 0) const = 0; /** * \brief predict the leaf index of each tree, the output will be nsample * @@ -174,18 +173,14 @@ class Predictor { * \param [in,out] dmat The input feature matrix. * \param [in,out] out_preds The output preds. * \param model Model to make predictions from. - * \param ntree_limit (Optional) The ntree limit. + * \param tree_end (Optional) The tree end index. */ virtual void PredictLeaf(DMatrix* dmat, HostDeviceVector* out_preds, const gbm::GBTreeModel& model, - unsigned ntree_limit = 0) const = 0; + unsigned tree_end = 0) const = 0; /** - * \fn virtual void Predictor::PredictContribution( DMatrix* dmat, - * std::vector* out_contribs, const gbm::GBTreeModel& model, - * unsigned ntree_limit = 0) = 0; - * * \brief feature contributions to individual predictions; the output will be * a vector of length (nfeats + 1) * num_output_group * nsample, arranged in * that order. @@ -193,7 +188,7 @@ class Predictor { * \param [in,out] dmat The input feature matrix. * \param [in,out] out_contribs The output feature contribs. * \param model Model to make predictions from. - * \param ntree_limit (Optional) The ntree limit. + * \param tree_end The tree end index. * \param tree_weights (Optional) Weights to multiply each tree by. * \param approximate Use fast approximate algorithm. * \param condition Condition on the condition_feature (0=no, -1=cond off, 1=cond on). @@ -203,7 +198,7 @@ class Predictor { virtual void PredictContribution(DMatrix* dmat, HostDeviceVector* out_contribs, const gbm::GBTreeModel& model, - unsigned ntree_limit = 0, + unsigned tree_end = 0, std::vector* tree_weights = nullptr, bool approximate = false, int condition = 0, @@ -212,7 +207,7 @@ class Predictor { virtual void PredictInteractionContributions(DMatrix* dmat, HostDeviceVector* out_contribs, const gbm::GBTreeModel& model, - unsigned ntree_limit = 0, + unsigned tree_end = 0, std::vector* tree_weights = nullptr, bool approximate = false) const = 0; diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index 765fea4d7..8b33838df 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -96,6 +96,24 @@ def from_cstr_to_pystr(data, length): return res +def _convert_ntree_limit(booster, ntree_limit, iteration_range): + if ntree_limit is not None and ntree_limit != 0: + warnings.warn( + "ntree_limit is deprecated, use `iteration_range` or model " + "slicing instead.", + UserWarning + ) + if iteration_range is not None and iteration_range[1] != 0: + raise ValueError( + "Only one of `iteration_range` and `ntree_limit` can be non zero." + ) + num_parallel_tree, num_groups = _get_booster_layer_trees(booster) + num_parallel_tree = max([num_parallel_tree, 1]) + num_groups = max([num_groups, 1]) + iteration_range = (0, ntree_limit // num_parallel_tree) + return iteration_range + + def _expect(expectations, got): """Translate input error into string. @@ -1111,6 +1129,34 @@ Objective = Callable[[np.ndarray, DMatrix], Tuple[np.ndarray, np.ndarray]] Metric = Callable[[np.ndarray, DMatrix], Tuple[str, float]] +def _get_booster_layer_trees(model: "Booster") -> Tuple[int, int]: + """Get number of trees added to booster per-iteration. This function will be removed + once `best_ntree_limit` is dropped in favor of `best_iteration`. Returns + `num_parallel_tree` and `num_groups`. + + """ + config = json.loads(model.save_config()) + booster = config["learner"]["gradient_booster"]["name"] + if booster == "gblinear": + num_parallel_tree = 0 + elif booster == "dart": + num_parallel_tree = int( + config["learner"]["gradient_booster"]["gbtree"]["gbtree_train_param"][ + "num_parallel_tree" + ] + ) + elif booster == "gbtree": + num_parallel_tree = int( + config["learner"]["gradient_booster"]["gbtree_train_param"][ + "num_parallel_tree" + ] + ) + else: + raise ValueError(f"Unknown booster: {booster}") + num_groups = int(config["learner"]["learner_model_param"]["num_class"]) + return num_parallel_tree, num_groups + + class Booster(object): # pylint: disable=too-many-public-methods """A Booster of XGBoost. @@ -1497,16 +1543,20 @@ class Booster(object): return self.eval_set([(data, name)], iteration) # pylint: disable=too-many-function-args - def predict(self, - data, - output_margin=False, - ntree_limit=0, - pred_leaf=False, - pred_contribs=False, - approx_contribs=False, - pred_interactions=False, - validate_features=True, - training=False): + def predict( + self, + data: DMatrix, + output_margin: bool = False, + ntree_limit: int = 0, + pred_leaf: bool = False, + pred_contribs: bool = False, + approx_contribs: bool = False, + pred_interactions: bool = False, + validate_features: bool = True, + training: bool = False, + iteration_range: Tuple[int, int] = (0, 0), + strict_shape: bool = False, + ) -> np.ndarray: """Predict with data. .. note:: This function is not thread safe except for ``gbtree`` booster. @@ -1518,33 +1568,32 @@ class Booster(object): Parameters ---------- - data : DMatrix + data : The dmatrix storing the input. - output_margin : bool + output_margin : Whether to output the raw untransformed margin value. - ntree_limit : int - Limit number of trees in the prediction; defaults to 0 (use all - trees). + ntree_limit : + Deprecated, use `iteration_range` instead. - pred_leaf : bool + pred_leaf : When this option is on, the output will be a matrix of (nsample, ntrees) with each record indicating the predicted leaf index of each sample in each tree. Note that the leaf index of a tree is unique per tree, so you may find leaf 1 in both tree 1 and tree 0. - pred_contribs : bool + pred_contribs : When this is True the output will be a matrix of size (nsample, nfeats + 1) with each record indicating the feature contributions (SHAP values) for that prediction. The sum of all feature contributions is equal to the raw untransformed margin value of the prediction. Note the final column is the bias term. - approx_contribs : bool + approx_contribs : Approximate the contributions of each feature - pred_interactions : bool + pred_interactions : When this is True the output will be a matrix of size (nsample, nfeats + 1, nfeats + 1) indicating the SHAP interaction values for each pair of features. The sum of each row (or column) of the @@ -1553,17 +1602,33 @@ class Booster(object): untransformed margin value of the prediction. Note the last row and column correspond to the bias term. - validate_features : bool + validate_features : When this is True, validate that the Booster's and data's feature_names are identical. Otherwise, it is assumed that the feature_names are the same. - training : bool + training : Whether the prediction value is used for training. This can effect `dart` booster, which performs dropouts during training iterations. .. versionadded:: 1.0.0 + iteration_range : + Specifies which layer of trees are used in prediction. For example, if a + random forest is trained with 100 rounds. Specifying `iteration_range=(10, + 20)`, then only the forests built during [10, 20) (half open set) rounds are + used in this prediction. + + .. versionadded:: 1.4.0 + + strict_shape : + When set to True, output shape is invariant to whether classification is used. + For both value and margin prediction, the output shape is (n_samples, + n_groups), n_groups == 1 when multi-class is not used. Default to False, in + which case the output shape can be (n_samples, ) if multi-class is not used. + + .. versionadded:: 1.4.0 + .. note:: Using ``predict()`` with DART booster If the booster object is DART type, ``predict()`` will not perform @@ -1575,64 +1640,50 @@ class Booster(object): prediction : numpy array """ - option_mask = 0x00 - if output_margin: - option_mask |= 0x01 - if pred_leaf: - option_mask |= 0x02 - if pred_contribs: - option_mask |= 0x04 - if approx_contribs: - option_mask |= 0x08 - if pred_interactions: - option_mask |= 0x10 - if not isinstance(data, DMatrix): - raise TypeError('Expecting data to be a DMatrix object, got: ', - type(data)) - + raise TypeError('Expecting data to be a DMatrix object, got: ', type(data)) if validate_features: self._validate_features(data) + iteration_range = _convert_ntree_limit(self, ntree_limit, iteration_range) + args = { + "type": 0, + "training": training, + "iteration_begin": iteration_range[0], + "iteration_end": iteration_range[1], + "strict_shape": strict_shape, + } - length = c_bst_ulong() - preds = ctypes.POINTER(ctypes.c_float)() - _check_call(_LIB.XGBoosterPredict(self.handle, data.handle, - ctypes.c_int(option_mask), - ctypes.c_uint(ntree_limit), - ctypes.c_int(training), - ctypes.byref(length), - ctypes.byref(preds))) - preds = ctypes2numpy(preds, length.value, np.float32) + def assign_type(t: int) -> None: + if args["type"] != 0: + raise ValueError("One type of prediction at a time.") + args["type"] = t + + if output_margin: + assign_type(1) + if pred_contribs: + assign_type(2 if not approx_contribs else 3) + if pred_interactions: + assign_type(4) if pred_leaf: - preds = preds.astype(np.int32, copy=False) - nrow = data.num_row() - if preds.size != nrow and preds.size % nrow == 0: - chunk_size = int(preds.size / nrow) - - if pred_interactions: - ngroup = int(chunk_size / ((data.num_col() + 1) * - (data.num_col() + 1))) - if ngroup == 1: - preds = preds.reshape(nrow, - data.num_col() + 1, - data.num_col() + 1) - else: - preds = preds.reshape(nrow, ngroup, - data.num_col() + 1, - data.num_col() + 1) - elif pred_contribs: - ngroup = int(chunk_size / (data.num_col() + 1)) - if ngroup == 1: - preds = preds.reshape(nrow, data.num_col() + 1) - else: - preds = preds.reshape(nrow, ngroup, data.num_col() + 1) - else: - preds = preds.reshape(nrow, chunk_size) - return preds + assign_type(5) + preds = ctypes.POINTER(ctypes.c_float)() + shape = ctypes.POINTER(c_bst_ulong)() + dims = c_bst_ulong() + _check_call( + _LIB.XGBoosterPredictFromDMatrix( + self.handle, + data.handle, + from_pystr_to_cstr(json.dumps(args)), + ctypes.byref(shape), + ctypes.byref(dims), + ctypes.byref(preds) + ) + ) + return _prediction_output(shape, dims, preds, False) def inplace_predict( self, - data, + data: Any, iteration_range: Tuple[int, int] = (0, 0), predict_type: str = "value", missing: float = np.nan, @@ -1665,26 +1716,24 @@ class Booster(object): The input data, must not be a view for numpy array. Set ``predictor`` to ``gpu_predictor`` for running prediction on CuPy array or CuDF DataFrame. - iteration_range : tuple - Specifies which layer of trees are used in prediction. For - example, if a random forest is trained with 100 rounds. Specifying - `iteration_range=(10, 20)`, then only the forests built during [10, - 20) (open set) rounds are used in this prediction. - predict_type : str + iteration_range : + See :py:meth:`xgboost.Booster.predict` for details. + predict_type : * `value` Output model prediction values. * `margin` Output the raw untransformed margin value. - missing : float - Value in the input data which needs to be present as a missing - value. + missing : + See :py:obj:`xgboost.DMatrix` for details. validate_features: See :py:meth:`xgboost.Booster.predict` for details. base_margin: See :py:obj:`xgboost.DMatrix` for details. + + .. versionadded:: 1.4.0 + strict_shape: - When set to True, output shape is invariant to whether classification is used. - For both value and margin prediction, the output shape is (n_samples, - n_groups), n_groups == 1 when multi-class is not used. Default to False, in - which case the output shape can be (n_samples, ) if multi-class is not used. + See :py:meth:`xgboost.Booster.predict` for details. + + .. versionadded:: 1.4.0 Returns ------- @@ -1772,7 +1821,7 @@ class Booster(object): interface["mask"] = interface["mask"].__cuda_array_interface__ interface_str = bytes(json.dumps(interface, indent=2), "utf-8") _check_call( - _LIB.XGBoosterPredictFromArrayInterface( + _LIB.XGBoosterPredictFromCudaArray( self.handle, interface_str, from_pystr_to_cstr(json.dumps(args)), @@ -1788,7 +1837,7 @@ class Booster(object): interfaces_str = _cudf_array_interfaces(data) _check_call( - _LIB.XGBoosterPredictFromArrayInterfaceColumns( + _LIB.XGBoosterPredictFromCudaColumnar( self.handle, interfaces_str, from_pystr_to_cstr(json.dumps(args)), diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py index d05fe4951..76b1b50a1 100644 --- a/python-package/xgboost/dask.py +++ b/python-package/xgboost/dask.py @@ -254,6 +254,7 @@ class DaskDMatrix: raise TypeError(_expect((dd.DataFrame, da.Array, dd.Series), type(label))) self._n_cols = data.shape[1] + assert isinstance(self._n_cols, int) self.worker_map: Dict[str, "distributed.Future"] = defaultdict(list) self.is_quantile: bool = False @@ -881,7 +882,7 @@ async def _train_async( return list(filter(lambda ret: ret is not None, results))[0] -def train( +def train( # pylint: disable=unused-argument client: "distributed.Client", params: Dict[str, Any], dtrain: DaskDMatrix, @@ -892,16 +893,17 @@ def train( early_stopping_rounds: Optional[int] = None, xgb_model: Optional[Booster] = None, verbose_eval: Union[int, bool] = True, - callbacks: Optional[List[TrainingCallback]] = None + callbacks: Optional[List[TrainingCallback]] = None, ) -> Any: - '''Train XGBoost model. + """Train XGBoost model. .. versionadded:: 1.0.0 .. note:: - Other parameters are the same as `xgboost.train` except for `evals_result`, which - is returned as part of function return value instead of argument. + Other parameters are the same as :py:func:`xgboost.train` except for + `evals_result`, which is returned as part of function return value instead of + argument. Parameters ---------- @@ -920,29 +922,17 @@ def train( {'booster': xgboost.Booster, 'history': {'train': {'logloss': ['0.48253', '0.35953']}, 'eval': {'logloss': ['0.480385', '0.357756']}}} - ''' + + """ _assert_dask_support() client = _xgb_get_client(client) # Get global configuration before transferring computation to another thread or # process. - global_config = config.get_config() - return client.sync(_train_async, - client=client, - global_config=global_config, - num_boost_round=num_boost_round, - obj=obj, - feval=feval, - params=params, - dtrain=dtrain, - evals=evals, - early_stopping_rounds=early_stopping_rounds, - verbose_eval=verbose_eval, - xgb_model=xgb_model, - callbacks=callbacks) + return client.sync(_train_async, global_config=config.get_config(), **locals()) -def _can_output_df(data: _DaskCollection, output_shape: Tuple) -> bool: - return isinstance(data, dd.DataFrame) and len(output_shape) <= 2 +def _can_output_df(is_df: bool, output_shape: Tuple) -> bool: + return is_df and len(output_shape) <= 2 async def _direct_predict_impl( @@ -954,8 +944,9 @@ async def _direct_predict_impl( meta: Dict[int, str], ) -> _DaskCollection: columns = list(meta.keys()) - if _can_output_df(data, output_shape): + if _can_output_df(isinstance(data, dd.DataFrame), output_shape): if base_margin is not None and isinstance(base_margin, da.Array): + # Easier for map_partitions base_margin_df: Optional[dd.DataFrame] = base_margin.to_dask_dataframe() else: base_margin_df = base_margin @@ -975,17 +966,21 @@ async def _direct_predict_impl( if base_margin is not None and isinstance( base_margin, (dd.Series, dd.DataFrame) ): + # Easier for map_blocks base_margin_array: Optional[da.Array] = base_margin.to_dask_array() else: base_margin_array = base_margin # Input data is 2-dim array, output can be 1(reg, binary)/2(multi-class, - # contrib)/3(contrib)/4(interaction) dims. + # contrib)/3(contrib, interaction)/4(interaction) dims. if len(output_shape) == 1: drop_axis: Union[int, List[int]] = [1] # drop from 2 to 1 dim. new_axis: Union[int, List[int]] = [] else: drop_axis = [] - new_axis = [i + 2 for i in range(len(output_shape) - 2)] + if isinstance(data, dd.DataFrame): + new_axis = list(range(len(output_shape) - 2)) + else: + new_axis = [i + 2 for i in range(len(output_shape) - 2)] predictions = da.map_blocks( mapped_predict, booster, @@ -1001,28 +996,21 @@ async def _direct_predict_impl( def _infer_predict_output( - booster: Booster, - data: Union[DaskDMatrix, _DaskCollection], - inplace: bool, - **kwargs: Any + booster: Booster, features: int, is_df: bool, inplace: bool, **kwargs: Any ) -> Tuple[Tuple[int, ...], Dict[int, str]]: """Create a dummy test sample to infer output shape for prediction.""" - if isinstance(data, DaskDMatrix): - features = data.num_col() - else: - features = data.shape[1] + assert isinstance(features, int) rng = numpy.random.RandomState(1994) test_sample = rng.randn(1, features) if inplace: - # clear the state to avoid gpu_id, gpu_predictor - booster = Booster(model_file=booster.save_raw()) - test_predt = booster.inplace_predict(test_sample, **kwargs) - else: - m = DMatrix(test_sample) - test_predt = booster.predict(m, **kwargs) + kwargs = kwargs.copy() + if kwargs.pop("predict_type") == "margin": + kwargs["output_margin"] = True + m = DMatrix(test_sample) + test_predt = booster.predict(m, validate_features=False, **kwargs) n_columns = test_predt.shape[1] if len(test_predt.shape) > 1 else 1 meta: Dict[int, str] = {} - if _can_output_df(data, test_predt.shape): + if _can_output_df(is_df, test_predt.shape): for i in range(n_columns): meta[i] = "f4" return test_predt.shape, meta @@ -1034,7 +1022,7 @@ async def _get_model_future( if isinstance(model, Booster): booster = await client.scatter(model, broadcast=True) elif isinstance(model, dict): - booster = await client.scatter(model["booster"]) + booster = await client.scatter(model["booster"], broadcast=True) elif isinstance(model, distributed.Future): booster = model if booster.type is not Booster: @@ -1059,6 +1047,8 @@ async def _predict_async( approx_contribs: bool, pred_interactions: bool, validate_features: bool, + iteration_range: Tuple[int, int], + strict_shape: bool, ) -> _DaskCollection: _booster = await _get_model_future(client, model) if not isinstance(data, (DaskDMatrix, da.Array, dd.DataFrame)): @@ -1077,43 +1067,51 @@ async def _predict_async( approx_contribs=approx_contribs, pred_interactions=pred_interactions, validate_features=validate_features, + iteration_range=iteration_range, + strict_shape=strict_shape, ) - if is_df and len(predt.shape) <= 2: + if _can_output_df(is_df, predt.shape): if lazy_isinstance(partition, "cudf", "core.dataframe.DataFrame"): import cudf - predt = cudf.DataFrame(predt, columns=columns) + predt = cudf.DataFrame(predt, columns=columns, dtype=numpy.float32) else: - predt = DataFrame(predt, columns=columns) + predt = DataFrame(predt, columns=columns, dtype=numpy.float32) return predt # Predict on dask collection directly. if isinstance(data, (da.Array, dd.DataFrame)): - _output_shape, meta = _infer_predict_output( - await _booster.result(), - data, + _output_shape, meta = await client.compute( + client.submit( + _infer_predict_output, + _booster, + features=data.shape[1], + is_df=isinstance(data, dd.DataFrame), + inplace=False, + output_margin=output_margin, + pred_leaf=pred_leaf, + pred_contribs=pred_contribs, + approx_contribs=approx_contribs, + pred_interactions=pred_interactions, + ) + ) + return await _direct_predict_impl( + mapped_predict, _booster, data, None, _output_shape, meta + ) + + output_shape, _ = await client.compute( + client.submit( + _infer_predict_output, + booster=_booster, + features=data.num_col(), + is_df=False, inplace=False, output_margin=output_margin, pred_leaf=pred_leaf, pred_contribs=pred_contribs, approx_contribs=approx_contribs, pred_interactions=pred_interactions, - validate_features=False, ) - return await _direct_predict_impl( - mapped_predict, _booster, data, None, _output_shape, meta - ) - - output_shape, _ = _infer_predict_output( - booster=await _booster.result(), - data=data, - inplace=False, - output_margin=output_margin, - pred_leaf=pred_leaf, - pred_contribs=pred_contribs, - approx_contribs=approx_contribs, - pred_interactions=pred_interactions, - validate_features=False, ) # Prediction on dask DMatrix. partition_order = data.partition_order @@ -1180,7 +1178,7 @@ async def _predict_async( futures[i], shape=(rows,) + output_shape[1:], dtype=numpy.float32 ) ) - predictions = await da.concatenate(arrays, axis=0) + predictions = da.concatenate(arrays, axis=0) return predictions @@ -1194,15 +1192,19 @@ def predict( # pylint: disable=unused-argument pred_contribs: bool = False, approx_contribs: bool = False, pred_interactions: bool = False, - validate_features: bool = True + validate_features: bool = True, + iteration_range: Tuple[int, int] = (0, 0), + strict_shape: bool = False, ) -> Any: '''Run prediction with a trained booster. .. note:: - Using ``inplace_predict `` might be faster when meta information like - ``base_margin`` is not needed. For other parameters, please see - ``Booster.predict``. + Using ``inplace_predict`` might be faster when some features are not needed. See + :py:meth:`xgboost.Booster.predict` for details on various parameters. When using + ``pred_interactions`` with mutli-class model, input should be ``da.Array`` or + ``DaskDMatrix`` due to limitation in ``da.map_blocks``. + .. versionadded:: 1.0.0 @@ -1232,69 +1234,83 @@ def predict( # pylint: disable=unused-argument ''' _assert_dask_support() client = _xgb_get_client(client) - return client.sync( - _predict_async, global_config=config.get_config(), **locals() - ) + return client.sync(_predict_async, global_config=config.get_config(), **locals()) -async def _inplace_predict_async( +async def _inplace_predict_async( # pylint: disable=too-many-branches client: "distributed.Client", global_config: Dict[str, Any], model: Union[Booster, Dict, "distributed.Future"], data: _DaskCollection, - iteration_range: Tuple[int, int] = (0, 0), - predict_type: str = 'value', - missing: float = numpy.nan + iteration_range: Tuple[int, int], + predict_type: str, + missing: float, + validate_features: bool, + base_margin: Optional[_DaskCollection], + strict_shape: bool, ) -> _DaskCollection: client = _xgb_get_client(client) booster = await _get_model_future(client, model) if not isinstance(data, (da.Array, dd.DataFrame)): raise TypeError(_expect([da.Array, dd.DataFrame], type(data))) + if base_margin is not None and not isinstance( + data, (da.Array, dd.DataFrame, dd.Series) + ): + raise TypeError(_expect([da.Array, dd.DataFrame, dd.Series], type(base_margin))) def mapped_predict( - booster: Booster, data: Any, is_df: bool, columns: List[int], _: Any + booster: Booster, data: Any, is_df: bool, columns: List[int], base_margin: Any ) -> Any: with config.config_context(**global_config): prediction = booster.inplace_predict( data, iteration_range=iteration_range, predict_type=predict_type, - missing=missing + missing=missing, + base_margin=base_margin, + validate_features=validate_features, + strict_shape=strict_shape, ) - if is_df and len(prediction.shape) <= 2: - if lazy_isinstance(data, 'cudf.core.dataframe', 'DataFrame'): + if _can_output_df(is_df, prediction.shape): + if lazy_isinstance(data, "cudf.core.dataframe", "DataFrame"): import cudf + prediction = cudf.DataFrame( prediction, columns=columns, dtype=numpy.float32 ) else: - # If it's from pandas, the partition is a numpy array - prediction = DataFrame( - prediction, columns=columns, dtype=numpy.float32 - ) + # If it's from pandas, the partition is a numpy array + prediction = DataFrame(prediction, columns=columns, dtype=numpy.float32) return prediction - - shape, meta = _infer_predict_output( - await booster.result(), - data, - True, - predict_type=predict_type, - iteration_range=iteration_range + # await turns future into value. + shape, meta = await client.compute( + client.submit( + _infer_predict_output, + booster, + features=data.shape[1], + is_df=isinstance(data, dd.DataFrame), + inplace=True, + predict_type=predict_type, + iteration_range=iteration_range, + ) ) return await _direct_predict_impl( - mapped_predict, booster, data, None, shape, meta + mapped_predict, booster, data, base_margin, shape, meta ) -def inplace_predict( # pylint: disable=unused-argument +def inplace_predict( # pylint: disable=unused-argument client: "distributed.Client", model: Union[TrainReturnT, Booster, "distributed.Future"], data: _DaskCollection, iteration_range: Tuple[int, int] = (0, 0), - predict_type: str = 'value', - missing: float = numpy.nan + predict_type: str = "value", + missing: float = numpy.nan, + validate_features: bool = True, + base_margin: Optional[_DaskCollection] = None, + strict_shape: bool = False, ) -> Any: - '''Inplace prediction. + """Inplace prediction. See doc in :py:meth:`xgboost.Booster.inplace_predict` for details. .. versionadded:: 1.1.0 @@ -1304,16 +1320,27 @@ def inplace_predict( # pylint: disable=unused-argument Specify the dask client used for training. Use default client returned from dask if it's set to None. model: - The trained model. It can be a distributed.Future so user can - pre-scatter it onto all workers. + See :py:func:`xgboost.dask.predict` for details. + data : + dask collection. iteration_range: - Specify the range of trees used for prediction. + See :py:meth:`xgboost.Booster.predict` for details. predict_type: - * 'value': Normal prediction result. - * 'margin': Output the raw untransformed margin value. + See :py:meth:`xgboost.Booster.inplace_predict` for details. missing: Value in the input data which needs to be present as a missing value. If None, defaults to np.nan. + base_margin: + See :py:obj:`xgboost.DMatrix` for details. Right now classifier is not well + supported with base_margin as it requires the size of base margin to be `n_classes + * n_samples`. + + .. versionadded:: 1.4.0 + + strict_shape: + See :py:meth:`xgboost.Booster.predict` for details. + + .. versionadded:: 1.4.0 Returns ------- @@ -1322,7 +1349,7 @@ def inplace_predict( # pylint: disable=unused-argument data is ``dask.dataframe.DataFrame``, return value can be ``dask.dataframe.Series``, ``dask.dataframe.DataFrame`` or ``dask.array.Array``, depending on the output shape. - ''' + """ _assert_dask_support() client = _xgb_get_client(client) return client.sync( @@ -1334,9 +1361,11 @@ async def _async_wrap_evaluation_matrices( client: "distributed.Client", **kwargs: Any ) -> Tuple[DaskDMatrix, Optional[List[Tuple[DaskDMatrix, str]]]]: """A switch function for async environment.""" + def _inner(**kwargs: Any) -> DaskDMatrix: m = DaskDMatrix(client=client, **kwargs) return m + train_dmatrix, evals = _wrap_evaluation_matrices(create_dmatrix=_inner, **kwargs) train_dmatrix = await train_dmatrix if evals is None: @@ -1351,25 +1380,45 @@ async def _async_wrap_evaluation_matrices( class DaskScikitLearnBase(XGBModel): - '''Base class for implementing scikit-learn interface with Dask''' + """Base class for implementing scikit-learn interface with Dask""" _client = None async def _predict_async( - self, data: _DaskCollection, - output_margin: bool = False, - validate_features: bool = True, - base_margin: Optional[_DaskCollection] = None + self, + data: _DaskCollection, + output_margin: bool, + validate_features: bool, + base_margin: Optional[_DaskCollection], + iteration_range: Optional[Tuple[int, int]], ) -> Any: - test_dmatrix = await DaskDMatrix( - client=self.client, data=data, base_margin=base_margin, - missing=self.missing - ) - pred_probs = await predict(client=self.client, - model=self.get_booster(), data=test_dmatrix, - output_margin=output_margin, - validate_features=validate_features) - return pred_probs + iteration_range = self._get_iteration_range(iteration_range) + if self._can_use_inplace_predict(): + predts = await inplace_predict( + client=self.client, + model=self.get_booster(), + data=data, + iteration_range=iteration_range, + predict_type="margin" if output_margin else "value", + missing=self.missing, + base_margin=base_margin, + validate_features=validate_features, + ) + if isinstance(predts, dd.DataFrame): + predts = predts.to_dask_array() + else: + test_dmatrix = await DaskDMatrix( + self.client, data=data, base_margin=base_margin, missing=self.missing + ) + predts = await predict( + self.client, + model=self.get_booster(), + data=test_dmatrix, + output_margin=output_margin, + validate_features=validate_features, + iteration_range=iteration_range, + ) + return predts def predict( self, @@ -1377,26 +1426,56 @@ class DaskScikitLearnBase(XGBModel): output_margin: bool = False, ntree_limit: Optional[int] = None, validate_features: bool = True, - base_margin: Optional[_DaskCollection] = None + base_margin: Optional[_DaskCollection] = None, + iteration_range: Optional[Tuple[int, int]] = None, ) -> Any: _assert_dask_support() - msg = '`ntree_limit` is not supported on dask, use model slicing instead.' + msg = "`ntree_limit` is not supported on dask, use `iteration_range` instead." assert ntree_limit is None, msg return self.client.sync( self._predict_async, X, output_margin=output_margin, validate_features=validate_features, - base_margin=base_margin + base_margin=base_margin, + iteration_range=iteration_range, ) + async def _apply_async( + self, + X: _DaskCollection, + iteration_range: Optional[Tuple[int, int]] = None, + ) -> Any: + iteration_range = self._get_iteration_range(iteration_range) + test_dmatrix = await DaskDMatrix(self.client, data=X, missing=self.missing) + predts = await predict( + self.client, + model=self.get_booster(), + data=test_dmatrix, + pred_leaf=True, + iteration_range=iteration_range, + ) + return predts + + def apply( + self, + X: _DaskCollection, + ntree_limit: Optional[int] = None, + iteration_range: Optional[Tuple[int, int]] = None, + ) -> Any: + _assert_dask_support() + msg = "`ntree_limit` is not supported on dask, use `iteration_range` instead." + assert ntree_limit is None, msg + return self.client.sync(self._apply_async, X, iteration_range=iteration_range) + def __await__(self) -> Awaitable[Any]: # Generate a coroutine wrapper to make this class awaitable. async def _() -> Awaitable[Any]: return self + return self.client.sync(_).__await__() - def __getstate__(self): + def __getstate__(self) -> Dict: this = self.__dict__.copy() if "_client" in this.keys(): del this["_client"] @@ -1404,7 +1483,7 @@ class DaskScikitLearnBase(XGBModel): @property def client(self) -> "distributed.Client": - '''The dask client used in this model.''' + """The dask client used in this model.""" client = _xgb_get_client(self._client) return client @@ -1494,7 +1573,7 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase): sample_weight_eval_set: Optional[List[_DaskCollection]] = None, base_margin_eval_set: Optional[List[_DaskCollection]] = None, feature_weights: Optional[_DaskCollection] = None, - callbacks: Optional[List[TrainingCallback]] = None + callbacks: Optional[List[TrainingCallback]] = None, ) -> "DaskXGBRegressor": _assert_dask_support() args = {k: v for k, v in locals().items() if k != "self"} @@ -1556,9 +1635,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase): else: obj = None model, metric, params = self._configure_fit( - booster=xgb_model, - eval_metric=eval_metric, - params=params + booster=xgb_model, eval_metric=eval_metric, params=params ) results = await train( client=self.client, @@ -1610,18 +1687,19 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase): X: _DaskCollection, validate_features: bool, output_margin: bool, - base_margin: Optional[_DaskCollection] + base_margin: Optional[_DaskCollection], + iteration_range: Optional[Tuple[int, int]], ) -> _DaskCollection: - test_dmatrix = await DaskDMatrix( - client=self.client, data=X, base_margin=base_margin, - missing=self.missing + if iteration_range is None: + iteration_range = (0, 0) + predts = await super()._predict_async( + data=X, + output_margin=output_margin, + validate_features=validate_features, + base_margin=base_margin, + iteration_range=iteration_range, ) - pred_probs = await predict(client=self.client, - model=self.get_booster(), - data=test_dmatrix, - validate_features=validate_features, - output_margin=output_margin) - return _cls_predict_proba(self.objective, pred_probs, da.vstack) + return _cls_predict_proba(self.objective, predts, da.vstack) # pylint: disable=missing-function-docstring def predict_proba( @@ -1630,37 +1708,49 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase): ntree_limit: Optional[int] = None, validate_features: bool = True, output_margin: bool = False, - base_margin: Optional[_DaskCollection] = None + base_margin: Optional[_DaskCollection] = None, + iteration_range: Optional[Tuple[int, int]] = None, ) -> Any: _assert_dask_support() - msg = '`ntree_limit` is not supported on dask, use model slicing instead.' + msg = "`ntree_limit` is not supported on dask, use `iteration_range` instead." assert ntree_limit is None, msg return self.client.sync( self._predict_proba_async, X=X, validate_features=validate_features, output_margin=output_margin, - base_margin=base_margin + base_margin=base_margin, + iteration_range=iteration_range, ) + predict_proba.__doc__ = XGBClassifier.predict_proba.__doc__ async def _predict_async( - self, data: _DaskCollection, - output_margin: bool = False, - validate_features: bool = True, - base_margin: Optional[_DaskCollection] = None + self, + data: _DaskCollection, + output_margin: bool, + validate_features: bool, + base_margin: Optional[_DaskCollection], + iteration_range: Optional[Tuple[int, int]], ) -> _DaskCollection: pred_probs = await super()._predict_async( - data, output_margin, validate_features, base_margin + data, output_margin, validate_features, base_margin, iteration_range ) if output_margin: return pred_probs - if self.n_classes_ == 2: + if len(pred_probs.shape) == 1: preds = (pred_probs > 0.5).astype(int) else: - preds = da.argmax(pred_probs, axis=1) + assert len(pred_probs.shape) == 2 + assert isinstance(pred_probs, da.Array) + # when using da.argmax directly, dask will construct a numpy based return + # array, which runs into error when computing GPU based prediction. + def _argmax(x: Any) -> Any: + return x.argmax(axis=1) + + preds = da.map_blocks(_argmax, pred_probs, drop_axis=1) return preds @@ -1770,7 +1860,7 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn): callbacks: Optional[List[TrainingCallback]] = None ) -> "DaskXGBRanker": _assert_dask_support() - args = {k: v for k, v in locals().items() if k != 'self'} + args = {k: v for k, v in locals().items() if k != "self"} return self.client.sync(self._fit_async, **args) # FIXME(trivialfis): arguments differ due to additional parameters like group and qid. diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py index 84632f4aa..fcb18319d 100644 --- a/python-package/xgboost/sklearn.py +++ b/python-package/xgboost/sklearn.py @@ -6,7 +6,8 @@ import warnings import json from typing import Union, Optional, List, Dict, Callable, Tuple, Any import numpy as np -from .core import Booster, DMatrix, XGBoostError, _deprecate_positional_args +from .core import Booster, DMatrix, XGBoostError +from .core import _deprecate_positional_args, _convert_ntree_limit from .core import Metric from .training import train from .data import _is_cudf_df, _is_cudf_ser, _is_cupy_array @@ -413,8 +414,8 @@ class XGBModel(XGBModelBase): # Simple optimization to gain speed (inspect is slow) return self - # this concatenates kwargs into paraemters, enabling `get_params` for - # obtaining parameters from keyword paraemters. + # this concatenates kwargs into parameters, enabling `get_params` for + # obtaining parameters from keyword parameters. for key, value in params.items(): if hasattr(self, key): setattr(self, key, value) @@ -747,26 +748,45 @@ class XGBModel(XGBModelBase): self._set_evaluation_result(evals_result) return self + def _can_use_inplace_predict(self) -> bool: + # When predictor is explicitly set, using `inplace_predict` might result into + # error with incompatible data type. + # Inplace predict doesn't handle as many data types as DMatrix, but it's + # sufficient for dask interface where input is simpiler. + params = self.get_params() + booster = self.booster + if params.get("predictor", None) is None and ( + booster is None or booster == "gbtree" + ): + return True + return False + + def _get_iteration_range( + self, iteration_range: Optional[Tuple[int, int]] + ) -> Tuple[int, int]: + if (iteration_range is None or iteration_range[1] == 0): + # Use best_iteration if defined. + try: + iteration_range = (0, self.best_iteration + 1) + except AttributeError: + iteration_range = (0, 0) + if self.booster == "gblinear": + iteration_range = (0, 0) + return iteration_range + def predict( self, X, output_margin=False, ntree_limit=None, validate_features=True, - base_margin=None + base_margin=None, + iteration_range=None, ): """ Predict with `X`. - .. note:: This function is not thread safe. - - For each booster object, predict can only be called from one thread. - If you want to run prediction using multiple thread, call ``xgb.copy()`` to make copies - of model object and then call ``predict()``. - - .. code-block:: python - - preds = bst.predict(dtest, ntree_limit=num_round) + .. note:: This function is only thread safe for `gbtree` Parameters ---------- @@ -775,37 +795,40 @@ class XGBModel(XGBModelBase): output_margin : bool Whether to output the raw untransformed margin value. ntree_limit : int - Limit number of trees in the prediction; defaults to best_ntree_limit if - defined (i.e. it has been trained with early stopping), otherwise 0 (use all - trees). + Deprecated, use `iteration_range` instead. validate_features : bool - When this is True, validate that the Booster's and data's feature_names are identical. - Otherwise, it is assumed that the feature_names are the same. + When this is True, validate that the Booster's and data's feature_names are + identical. Otherwise, it is assumed that the feature_names are the same. base_margin : array_like Margin added to prediction. + iteration_range : + Specifies which layer of trees are used in prediction. For example, if a + random forest is trained with 100 rounds. Specifying `iteration_range=(10, + 20)`, then only the forests built during [10, 20) (half open set) rounds are + used in this prediction. + .. versionadded:: 1.4.0 Returns ------- prediction : numpy array """ - # pylint: disable=missing-docstring,invalid-name - test_dmatrix = DMatrix(X, base_margin=base_margin, - missing=self.missing, nthread=self.n_jobs) - # get ntree_limit to use - if none specified, default to - # best_ntree_limit if defined, otherwise 0. - if ntree_limit is None: - try: - ntree_limit = self.best_ntree_limit - except AttributeError: - ntree_limit = 0 + iteration_range = _convert_ntree_limit( + self.get_booster(), ntree_limit, iteration_range + ) + iteration_range = self._get_iteration_range(iteration_range) + test = DMatrix( + X, base_margin=base_margin, missing=self.missing, nthread=self.n_jobs + ) return self.get_booster().predict( - test_dmatrix, + data=test, + iteration_range=iteration_range, output_margin=output_margin, - ntree_limit=ntree_limit, - validate_features=validate_features + validate_features=validate_features, ) - def apply(self, X, ntree_limit=0): + def apply( + self, X, ntree_limit: int = 0, iteration_range: Optional[Tuple[int, int]] = None + ) -> np.ndarray: """Return the predicted leaf every tree for each sample. Parameters @@ -823,10 +846,16 @@ class XGBModel(XGBModelBase): leaf x ends up in. Leaves are numbered within ``[0; 2**(self.max_depth+1))``, possibly with gaps in the numbering. """ + iteration_range = _convert_ntree_limit( + self.get_booster(), ntree_limit, iteration_range + ) + iteration_range = self._get_iteration_range(iteration_range) test_dmatrix = DMatrix(X, missing=self.missing, nthread=self.n_jobs) - return self.get_booster().predict(test_dmatrix, - pred_leaf=True, - ntree_limit=ntree_limit) + return self.get_booster().predict( + test_dmatrix, + pred_leaf=True, + iteration_range=iteration_range + ) def evals_result(self): """Return the evaluation results. @@ -945,8 +974,7 @@ class XGBModel(XGBModelBase): 'Coefficients are not defined for Booster type {}' .format(self.booster)) b = self.get_booster() - coef = np.array(json.loads( - b.get_dump(dump_format='json')[0])['weight']) + coef = np.array(json.loads(b.get_dump(dump_format='json')[0])['weight']) # Logic for multiclass classification n_classes = getattr(self, 'n_classes_', None) if n_classes is not None: @@ -1157,14 +1185,16 @@ class XGBClassifier(XGBModel, XGBClassifierBase): output_margin=False, ntree_limit=None, validate_features=True, - base_margin=None + base_margin=None, + iteration_range: Optional[Tuple[int, int]] = None, ): class_probs = super().predict( X=X, output_margin=output_margin, ntree_limit=ntree_limit, validate_features=validate_features, - base_margin=base_margin + base_margin=base_margin, + iteration_range=iteration_range, ) if output_margin: # If output_margin is active, simply return the scores @@ -1180,29 +1210,34 @@ class XGBClassifier(XGBModel, XGBClassifierBase): return self._le.inverse_transform(column_indexes) return column_indexes - def predict_proba(self, X, ntree_limit=None, validate_features=False, - base_margin=None): + def predict_proba( + self, + X, + ntree_limit=None, + validate_features=False, + base_margin=None, + iteration_range: Optional[Tuple[int, int]] = None, + ) -> np.ndarray: """ Predict the probability of each `X` example being of a given class. - .. note:: This function is not thread safe - - For each booster object, predict can only be called from one - thread. If you want to run prediction using multiple thread, call - ``xgb.copy()`` to make copies of model object and then call predict + .. note:: This function is only thread safe for `gbtree` Parameters ---------- X : array_like Feature matrix. ntree_limit : int - Limit number of trees in the prediction; defaults to best_ntree_limit if - defined (i.e. it has been trained with early stopping), otherwise 0 (use all - trees). + Deprecated, use `iteration_range` instead. validate_features : bool When this is True, validate that the Booster's and data's feature_names are identical. Otherwise, it is assumed that the feature_names are the same. base_margin : array_like Margin added to prediction. + iteration_range : + Specifies which layer of trees are used in prediction. For example, if a + random forest is trained with 100 rounds. Specifying `iteration_range=(10, + 20)`, then only the forests built during [10, 20) (half open set) rounds are + used in this prediction. Returns ------- @@ -1215,7 +1250,8 @@ class XGBClassifier(XGBModel, XGBClassifierBase): output_margin=False, ntree_limit=ntree_limit, validate_features=validate_features, - base_margin=base_margin + base_margin=base_margin, + iteration_range=iteration_range ) return _cls_predict_proba(self.objective, class_probs, np.vstack) diff --git a/python-package/xgboost/training.py b/python-package/xgboost/training.py index b3290a9e5..ef488baa6 100644 --- a/python-package/xgboost/training.py +++ b/python-package/xgboost/training.py @@ -4,9 +4,8 @@ """Training Library containing training routines.""" import warnings import copy -import json import numpy as np -from .core import Booster, XGBoostError +from .core import Booster, XGBoostError, _get_booster_layer_trees from .compat import (SKLEARN_INSTALLED, XGBStratifiedKFold) from . import callback @@ -91,24 +90,7 @@ def _train_internal(params, dtrain, # These should be moved into callback functions `after_training`, but until old # callbacks are removed, the train function is the only place for setting the # attributes. - config = json.loads(bst.save_config()) - booster = config['learner']['gradient_booster']['name'] - if booster == 'gblinear': - num_parallel_tree = 0 - elif booster == 'dart': - num_parallel_tree = int( - config['learner']['gradient_booster']['gbtree']['gbtree_train_param'][ - 'num_parallel_tree' - ] - ) - elif booster == 'gbtree': - num_parallel_tree = int( - config['learner']['gradient_booster']['gbtree_train_param'][ - 'num_parallel_tree'] - ) - else: - raise ValueError(f'Unknown booster: {booster}') - + num_parallel_tree, _ = _get_booster_layer_trees(bst) if bst.attr('best_score') is not None: bst.best_score = float(bst.attr('best_score')) bst.best_iteration = int(bst.attr('best_iteration')) diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index a131d965e..c81fddf2c 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -619,20 +619,58 @@ XGB_DLL int XGBoosterPredict(BoosterHandle handle, CHECK_HANDLE(); auto *learner = static_cast(handle); auto& entry = learner->GetThreadLocal().prediction_entry; - learner->Predict( - *static_cast*>(dmat), - (option_mask & 1) != 0, - &entry.predictions, ntree_limit, - static_cast(training), - (option_mask & 2) != 0, - (option_mask & 4) != 0, - (option_mask & 8) != 0, - (option_mask & 16) != 0); + auto iteration_end = GetIterationFromTreeLimit(ntree_limit, learner); + learner->Predict(*static_cast *>(dmat), + (option_mask & 1) != 0, &entry.predictions, 0, iteration_end, + static_cast(training), (option_mask & 2) != 0, + (option_mask & 4) != 0, (option_mask & 8) != 0, + (option_mask & 16) != 0); *out_result = dmlc::BeginPtr(entry.predictions.ConstHostVector()); *len = static_cast(entry.predictions.Size()); API_END(); } +XGB_DLL int XGBoosterPredictFromDMatrix(BoosterHandle handle, + DMatrixHandle dmat, + char const* c_json_config, + xgboost::bst_ulong const **out_shape, + xgboost::bst_ulong *out_dim, + bst_float const **out_result) { + API_BEGIN(); + if (handle == nullptr) { + LOG(FATAL) << "Booster has not been intialized or has already been disposed."; + } + if (dmat == nullptr) { + LOG(FATAL) << "DMatrix has not been intialized or has already been disposed."; + } + auto config = Json::Load(StringView{c_json_config}); + + auto *learner = static_cast(handle); + auto& entry = learner->GetThreadLocal().prediction_entry; + auto p_m = *static_cast *>(dmat); + auto type = PredictionType(get(config["type"])); + auto iteration_begin = get(config["iteration_begin"]); + auto iteration_end = get(config["iteration_end"]); + learner->Predict( + *static_cast *>(dmat), + type == PredictionType::kMargin, &entry.predictions, iteration_begin, + iteration_end, get(config["training"]), + type == PredictionType::kLeaf, type == PredictionType::kContribution, + type == PredictionType::kApproxContribution, + type == PredictionType::kInteraction); + *out_result = dmlc::BeginPtr(entry.predictions.ConstHostVector()); + auto &shape = learner->GetThreadLocal().prediction_shape; + auto chunksize = p_m->Info().num_row_ == 0 ? 0 : entry.predictions.Size() / p_m->Info().num_row_; + auto rounds = iteration_end - iteration_begin; + rounds = rounds == 0 ? learner->BoostedRounds() : rounds; + // Determine shape + bool strict_shape = get(config["strict_shape"]); + CalcPredictShape(strict_shape, type, p_m->Info().num_row_, + p_m->Info().num_col_, chunksize, learner->Groups(), rounds, + &shape, out_dim); + *out_shape = dmlc::BeginPtr(shape); + API_END(); +} template void InplacePredictImpl(std::shared_ptr x, std::shared_ptr p_m, @@ -705,7 +743,7 @@ XGB_DLL int XGBoosterPredictFromCSR(BoosterHandle handle, char const *indptr, } #if !defined(XGBOOST_USE_CUDA) -XGB_DLL int XGBoosterPredictFromArrayInterface( +XGB_DLL int XGBoosterPredictFromCUDAArray( BoosterHandle handle, char const *c_json_strs, char const *c_json_config, DMatrixHandle m, xgboost::bst_ulong const **out_shape, xgboost::bst_ulong *out_dim, const float **out_result) { @@ -715,7 +753,7 @@ XGB_DLL int XGBoosterPredictFromArrayInterface( API_END(); } -XGB_DLL int XGBoosterPredictFromArrayInterfaceColumns( +XGB_DLL int XGBoosterPredictFromCUDAColumnar( BoosterHandle handle, char const *c_json_strs, char const *c_json_config, DMatrixHandle m, xgboost::bst_ulong const **out_shape, xgboost::bst_ulong *out_dim, const float **out_result) { diff --git a/src/c_api/c_api.cu b/src/c_api/c_api.cu index baa4d8fdf..406190ed6 100644 --- a/src/c_api/c_api.cu +++ b/src/c_api/c_api.cu @@ -66,8 +66,7 @@ int InplacePreidctCuda(BoosterHandle handle, char const *c_json_strs, API_END(); } -// A hidden API as cache id is not being supported yet. -XGB_DLL int XGBoosterPredictFromArrayInterfaceColumns( +XGB_DLL int XGBoosterPredictFromCudaColumnar( BoosterHandle handle, char const *c_json_strs, char const *c_json_config, DMatrixHandle m, xgboost::bst_ulong const **out_shape, xgboost::bst_ulong *out_dim, const float **out_result) { @@ -79,8 +78,7 @@ XGB_DLL int XGBoosterPredictFromArrayInterfaceColumns( handle, c_json_strs, c_json_config, p_m, out_shape, out_dim, out_result); } -// A hidden API as cache id is not being supported yet. -XGB_DLL int XGBoosterPredictFromArrayInterface( +XGB_DLL int XGBoosterPredictFromCudaArray( BoosterHandle handle, char const *c_json_strs, char const *c_json_config, DMatrixHandle m, xgboost::bst_ulong const **out_shape, xgboost::bst_ulong *out_dim, const float **out_result) { diff --git a/src/c_api/c_api_utils.h b/src/c_api/c_api_utils.h index 0ebc0b8be..2e3af08cb 100644 --- a/src/c_api/c_api_utils.h +++ b/src/c_api/c_api_utils.h @@ -9,6 +9,7 @@ #include #include "xgboost/logging.h" +#include "xgboost/json.h" #include "xgboost/learner.h" namespace xgboost { @@ -30,8 +31,8 @@ inline void CalcPredictShape(bool strict_shape, PredictionType type, size_t rows std::vector *out_shape, xgboost::bst_ulong *out_dim) { auto &shape = *out_shape; - if ((type == PredictionType::kMargin || type == PredictionType::kValue) && - rows != 0) { + if (type == PredictionType::kMargin && rows != 0) { + // When kValue is used, softmax can change the chunksize. CHECK_EQ(chunksize, groups); } @@ -110,5 +111,35 @@ inline void CalcPredictShape(bool strict_shape, PredictionType type, size_t rows std::accumulate(shape.cbegin(), shape.cend(), 1, std::multiplies<>{}), chunksize * rows); } + +// Reverse the ntree_limit in old prediction API. +inline uint32_t GetIterationFromTreeLimit(uint32_t ntree_limit, Learner *learner) { + // On Python and R, `best_ntree_limit` is set to `best_iteration * num_parallel_tree`. + // To reverse it we just divide it by `num_parallel_tree`. + if (ntree_limit != 0) { + learner->Configure(); + uint32_t num_parallel_tree = 0; + + Json config{Object()}; + learner->SaveConfig(&config); + auto const &booster = + get(config["learner"]["gradient_booster"]["name"]); + if (booster == "gblinear") { + num_parallel_tree = 0; + } else if (booster == "dart") { + num_parallel_tree = std::stoi( + get(config["learner"]["gradient_booster"]["gbtree"] + ["gbtree_train_param"]["num_parallel_tree"])); + } else if (booster == "gbtree") { + num_parallel_tree = std::stoi(get( + (config["learner"]["gradient_booster"]["gbtree_train_param"] + ["num_parallel_tree"]))); + } else { + LOG(FATAL) << "Unknown booster:" << booster; + } + ntree_limit /= std::max(num_parallel_tree, 1u); + } + return ntree_limit; +} } // namespace xgboost #endif // XGBOOST_C_API_C_API_UTILS_H_ diff --git a/src/cli_main.cc b/src/cli_main.cc index 3b9f8267c..82140c9ca 100644 --- a/src/cli_main.cc +++ b/src/cli_main.cc @@ -25,6 +25,7 @@ #include "common/config.h" #include "common/io.h" #include "common/version.h" +#include "c_api/c_api_utils.h" namespace xgboost { enum CLITask { @@ -58,6 +59,8 @@ struct CLIParam : public XGBoostParameter { int dsplit; /*!\brief limit number of trees in prediction */ int ntree_limit; + int iteration_begin; + int iteration_end; /*!\brief whether to directly output margin value */ bool pred_margin; /*! \brief whether dump statistics along with model */ @@ -109,7 +112,11 @@ struct CLIParam : public XGBoostParameter { .add_enum("row", 2) .describe("Data split mode."); DMLC_DECLARE_FIELD(ntree_limit).set_default(0).set_lower_bound(0) - .describe("Number of trees used for prediction, 0 means use all trees."); + .describe("(Deprecated) Use iteration_begin/iteration_end instead."); + DMLC_DECLARE_FIELD(iteration_begin).set_default(0).set_lower_bound(0) + .describe("Begining of boosted tree iteration used for prediction."); + DMLC_DECLARE_FIELD(iteration_end).set_default(0).set_lower_bound(0) + .describe("End of boosted tree iteration used for prediction. 0 means all the trees."); DMLC_DECLARE_FIELD(pred_margin).set_default(false) .describe("Whether to predict margin value instead of probability."); DMLC_DECLARE_FIELD(dump_stats).set_default(false) @@ -334,7 +341,13 @@ class CLI { LOG(INFO) << "Start prediction..."; HostDeviceVector preds; - learner_->Predict(dtest, param_.pred_margin, &preds, param_.ntree_limit); + if (param_.ntree_limit != 0) { + param_.iteration_end = GetIterationFromTreeLimit(param_.ntree_limit, learner_.get()); + LOG(WARNING) << "`ntree_limit` is deprecated, use `iteration_begin` and " + "`iteration_end` instead."; + } + learner_->Predict(dtest, param_.pred_margin, &preds, param_.iteration_begin, + param_.iteration_end); LOG(CONSOLE) << "Writing prediction to " << param_.name_pred; std::unique_ptr fo( diff --git a/src/gbm/gblinear.cc b/src/gbm/gblinear.cc index cc2a4a439..e9c2f27aa 100644 --- a/src/gbm/gblinear.cc +++ b/src/gbm/gblinear.cc @@ -47,6 +47,12 @@ struct GBLinearTrainParam : public XGBoostParameter { .describe("Maximum rows per batch."); } }; + +void LinearCheckLayer(unsigned layer_begin, unsigned layer_end) { + CHECK_EQ(layer_begin, 0) << "Linear booster does not support prediction range."; + CHECK_EQ(layer_end, 0) << "Linear booster does not support prediction range."; +} + /*! * \brief gradient boosted linear model */ @@ -130,20 +136,19 @@ class GBLinear : public GradientBooster { monitor_.Stop("DoBoost"); } - void PredictBatch(DMatrix *p_fmat, - PredictionCacheEntry *predts, - bool, unsigned ntree_limit) override { + void PredictBatch(DMatrix *p_fmat, PredictionCacheEntry *predts, + bool training, unsigned layer_begin, unsigned layer_end) override { monitor_.Start("PredictBatch"); + LinearCheckLayer(layer_begin, layer_end); auto* out_preds = &predts->predictions; - CHECK_EQ(ntree_limit, 0U) - << "GBLinear::Predict ntrees is only valid for gbtree predictor"; this->PredictBatchInternal(p_fmat, &out_preds->HostVector()); monitor_.Stop("PredictBatch"); } // add base margin void PredictInstance(const SparsePage::Inst &inst, std::vector *out_preds, - unsigned) override { + unsigned layer_begin, unsigned layer_end) override { + LinearCheckLayer(layer_begin, layer_end); const int ngroup = model_.learner_model_param->num_output_group; for (int gid = 0; gid < ngroup; ++gid) { this->Pred(inst, dmlc::BeginPtr(*out_preds), gid, @@ -151,16 +156,15 @@ class GBLinear : public GradientBooster { } } - void PredictLeaf(DMatrix *, HostDeviceVector *, unsigned) override { + void PredictLeaf(DMatrix *, HostDeviceVector *, unsigned, unsigned) override { LOG(FATAL) << "gblinear does not support prediction of leaf index"; } void PredictContribution(DMatrix* p_fmat, HostDeviceVector* out_contribs, - unsigned ntree_limit, bool, int, unsigned) override { + unsigned layer_begin, unsigned layer_end, bool, int, unsigned) override { model_.LazyInitModel(); - CHECK_EQ(ntree_limit, 0U) - << "GBLinear::PredictContribution: ntrees is only valid for gbtree predictor"; + LinearCheckLayer(layer_begin, layer_end); const auto& base_margin = p_fmat->Info().base_margin_.ConstHostVector(); const int ngroup = model_.learner_model_param->num_output_group; const size_t ncolumns = model_.learner_model_param->num_feature + 1; @@ -197,7 +201,8 @@ class GBLinear : public GradientBooster { void PredictInteractionContributions(DMatrix* p_fmat, HostDeviceVector* out_contribs, - unsigned, bool) override { + unsigned layer_begin, unsigned layer_end, bool) override { + LinearCheckLayer(layer_begin, layer_end); std::vector& contribs = out_contribs->HostVector(); // linear models have no interaction effects diff --git a/src/gbm/gbtree.cc b/src/gbm/gbtree.cc index 1706842e2..30732dbd8 100644 --- a/src/gbm/gbtree.cc +++ b/src/gbm/gbtree.cc @@ -414,7 +414,7 @@ void GBTree::Slice(int32_t layer_begin, int32_t layer_end, int32_t step, auto layer_trees = this->LayerTrees(); layer_end = layer_end == 0 ? model_.trees.size() / layer_trees : layer_end; - CHECK_GE(layer_end, layer_begin); + CHECK_GT(layer_end, layer_begin); CHECK_GE(step, 1); int32_t n_layers = (layer_end - layer_begin) / step; std::vector> &out_trees = out_model.trees; @@ -438,10 +438,35 @@ void GBTree::Slice(int32_t layer_begin, int32_t layer_end, int32_t step, void GBTree::PredictBatch(DMatrix* p_fmat, PredictionCacheEntry* out_preds, bool, - unsigned ntree_limit) { + unsigned layer_begin, + unsigned layer_end) { CHECK(configured_); + if (layer_end == 0) { + layer_end = this->BoostedRounds(); + } + if (layer_begin != 0 || layer_end < out_preds->version) { + // cache is dropped. + out_preds->version = 0; + } + bool reset = false; + if (layer_begin == 0) { + layer_begin = out_preds->version; + } else { + // When begin layer is not 0, the cache is not useful. + reset = true; + } + + uint32_t tree_begin, tree_end; + std::tie(tree_begin, tree_end) = + detail::LayerToTree(model_, tparam_, layer_begin, layer_end); GetPredictor(&out_preds->predictions, p_fmat) - ->PredictBatch(p_fmat, out_preds, model_, 0, ntree_limit); + ->PredictBatch(p_fmat, out_preds, model_, tree_begin, tree_end); + if (reset) { + out_preds->version = 0; + } else { + uint32_t delta = layer_end - out_preds->version; + out_preds->Update(delta); + } } std::unique_ptr const & @@ -603,13 +628,14 @@ class Dart : public GBTree { void PredictBatch(DMatrix* p_fmat, PredictionCacheEntry* p_out_preds, bool training, - unsigned ntree_limit) override { + unsigned layer_begin, + unsigned layer_end) override { DropTrees(training); int num_group = model_.learner_model_param->num_output_group; - ntree_limit *= num_group; - if (ntree_limit == 0 || ntree_limit > model_.trees.size()) { - ntree_limit = static_cast(model_.trees.size()); - } + uint32_t tree_begin, tree_end; + std::tie(tree_begin, tree_end) = + detail::LayerToTree(model_, tparam_, layer_begin, layer_end); + size_t n = num_group * p_fmat->Info().num_row_; const auto &base_margin = p_fmat->Info().base_margin_.ConstHostVector(); auto& out_preds = p_out_preds->predictions.HostVector(); @@ -623,26 +649,24 @@ class Dart : public GBTree { } const int nthread = omp_get_max_threads(); InitThreadTemp(nthread); - PredLoopSpecalize(p_fmat, &out_preds, num_group, 0, ntree_limit); + PredLoopSpecalize(p_fmat, &out_preds, num_group, tree_begin, tree_end); } void PredictInstance(const SparsePage::Inst &inst, std::vector *out_preds, - unsigned ntree_limit) override { + unsigned layer_begin, unsigned layer_end) override { DropTrees(false); if (thread_temp_.size() == 0) { thread_temp_.resize(1, RegTree::FVec()); thread_temp_[0].Init(model_.learner_model_param->num_feature); } out_preds->resize(model_.learner_model_param->num_output_group); - ntree_limit *= model_.learner_model_param->num_output_group; - if (ntree_limit == 0 || ntree_limit > model_.trees.size()) { - ntree_limit = static_cast(model_.trees.size()); - } + uint32_t tree_begin, tree_end; + std::tie(tree_begin, tree_end) = detail::LayerToTree(model_, tparam_, layer_begin, layer_end); // loop over output groups for (uint32_t gid = 0; gid < model_.learner_model_param->num_output_group; ++gid) { (*out_preds)[gid] = - PredValue(inst, gid, &thread_temp_[0], 0, ntree_limit) + + PredValue(inst, gid, &thread_temp_[0], 0, tree_end) + model_.learner_model_param->base_score; } } @@ -653,22 +677,25 @@ class Dart : public GBTree { void PredictContribution(DMatrix* p_fmat, HostDeviceVector* out_contribs, - unsigned ntree_limit, bool approximate, int, + unsigned layer_begin, unsigned layer_end, bool approximate, int, unsigned) override { CHECK(configured_); + uint32_t tree_begin, tree_end; + std::tie(tree_begin, tree_end) = detail::LayerToTree(model_, tparam_, layer_begin, layer_end); cpu_predictor_->PredictContribution(p_fmat, out_contribs, model_, - ntree_limit, &weight_drop_, approximate); + tree_end, &weight_drop_, approximate); } - void PredictInteractionContributions(DMatrix* p_fmat, - HostDeviceVector* out_contribs, - unsigned ntree_limit, bool approximate) override { + void PredictInteractionContributions( + DMatrix *p_fmat, HostDeviceVector *out_contribs, + unsigned layer_begin, unsigned layer_end, bool approximate) override { CHECK(configured_); + uint32_t tree_begin, tree_end; + std::tie(tree_begin, tree_end) = detail::LayerToTree(model_, tparam_, layer_begin, layer_end); cpu_predictor_->PredictInteractionContributions(p_fmat, out_contribs, model_, - ntree_limit, &weight_drop_, approximate); + tree_end, &weight_drop_, approximate); } - protected: inline void PredLoopSpecalize( DMatrix* p_fmat, diff --git a/src/gbm/gbtree.h b/src/gbm/gbtree.h index 059804e58..e4a0a53f6 100644 --- a/src/gbm/gbtree.h +++ b/src/gbm/gbtree.h @@ -164,7 +164,9 @@ inline std::pair LayerToTree(gbm::GBTreeModel const &model, if (tree_end == 0) { tree_end = static_cast(model.trees.size()); } - CHECK_LT(tree_begin, tree_end); + if (model.trees.size() != 0) { + CHECK_LE(tree_begin, tree_end); + } return {tree_begin, tree_end}; } @@ -260,10 +262,8 @@ class GBTree : public GradientBooster { return model_.trees.size() / this->LayerTrees(); } - void PredictBatch(DMatrix* p_fmat, - PredictionCacheEntry* out_preds, - bool training, - unsigned ntree_limit) override; + void PredictBatch(DMatrix *p_fmat, PredictionCacheEntry *out_preds, + bool training, unsigned layer_begin, unsigned layer_end) override; void InplacePredict(dmlc::any const &x, std::shared_ptr p_m, float missing, PredictionCacheEntry *out_preds, @@ -297,33 +297,49 @@ class GBTree : public GradientBooster { void PredictInstance(const SparsePage::Inst& inst, std::vector* out_preds, - unsigned ntree_limit) override { + uint32_t layer_begin, uint32_t layer_end) override { CHECK(configured_); + uint32_t tree_begin, tree_end; + std::tie(tree_begin, tree_end) = detail::LayerToTree(model_, tparam_, layer_begin, layer_end); cpu_predictor_->PredictInstance(inst, out_preds, model_, - ntree_limit); + tree_end); } void PredictLeaf(DMatrix* p_fmat, HostDeviceVector* out_preds, - unsigned ntree_limit) override { - this->GetPredictor()->PredictLeaf(p_fmat, out_preds, model_, ntree_limit); + uint32_t layer_begin, uint32_t layer_end) override { + uint32_t tree_begin, tree_end; + std::tie(tree_begin, tree_end) = detail::LayerToTree(model_, tparam_, layer_begin, layer_end); + CHECK_EQ(tree_begin, 0) << "Predict leaf supports only iteration end: (0, " + "n_iteration), use model slicing instead."; + this->GetPredictor()->PredictLeaf(p_fmat, out_preds, model_, tree_end); } void PredictContribution(DMatrix* p_fmat, HostDeviceVector* out_contribs, - unsigned ntree_limit, bool approximate, + uint32_t layer_begin, uint32_t layer_end, bool approximate, int, unsigned) override { CHECK(configured_); + uint32_t tree_begin, tree_end; + std::tie(tree_begin, tree_end) = detail::LayerToTree(model_, tparam_, layer_begin, layer_end); + CHECK_EQ(tree_begin, 0) + << "Predict contribution supports only iteration end: (0, " + "n_iteration), using model slicing instead."; this->GetPredictor()->PredictContribution( - p_fmat, out_contribs, model_, ntree_limit, nullptr, approximate); + p_fmat, out_contribs, model_, tree_end, nullptr, approximate); } - void PredictInteractionContributions(DMatrix* p_fmat, - HostDeviceVector* out_contribs, - unsigned ntree_limit, bool approximate) override { + void PredictInteractionContributions( + DMatrix *p_fmat, HostDeviceVector *out_contribs, + uint32_t layer_begin, uint32_t layer_end, bool approximate) override { CHECK(configured_); - this->GetPredictor()->PredictInteractionContributions(p_fmat, out_contribs, model_, - ntree_limit, nullptr, approximate); + uint32_t tree_begin, tree_end; + std::tie(tree_begin, tree_end) = detail::LayerToTree(model_, tparam_, layer_begin, layer_end); + CHECK_EQ(tree_begin, 0) + << "Predict interaction contribution supports only iteration end: (0, " + "n_iteration), using model slicing instead."; + this->GetPredictor()->PredictInteractionContributions( + p_fmat, out_contribs, model_, tree_end, nullptr, approximate); } std::vector DumpModel(const FeatureMap& fmap, diff --git a/src/learner.cc b/src/learner.cc index b4b7f89d9..c9e816c13 100644 --- a/src/learner.cc +++ b/src/learner.cc @@ -22,6 +22,7 @@ #include "dmlc/any.h" #include "xgboost/base.h" +#include "xgboost/c_api.h" #include "xgboost/data.h" #include "xgboost/model.h" #include "xgboost/predictor.h" @@ -996,7 +997,7 @@ class LearnerImpl : public LearnerIO { auto& predt = local_cache->Cache(train, generic_parameters_.gpu_id); monitor_.Start("PredictRaw"); - this->PredictRaw(train.get(), &predt, true); + this->PredictRaw(train.get(), &predt, true, 0, 0); TrainingObserver::Instance().Observe(predt.predictions, "Predictions"); monitor_.Stop("PredictRaw"); @@ -1057,7 +1058,7 @@ class LearnerImpl : public LearnerIO { std::shared_ptr m = data_sets[i]; auto &predt = local_cache->Cache(m, generic_parameters_.gpu_id); this->ValidateDMatrix(m.get(), false); - this->PredictRaw(m.get(), &predt, false); + this->PredictRaw(m.get(), &predt, false, 0, 0); auto &out = output_predictions_.Cache(m, generic_parameters_.gpu_id).predictions; out.Resize(predt.predictions.Size()); @@ -1075,8 +1076,8 @@ class LearnerImpl : public LearnerIO { } void Predict(std::shared_ptr data, bool output_margin, - HostDeviceVector* out_preds, unsigned ntree_limit, - bool training, + HostDeviceVector *out_preds, unsigned layer_begin, + unsigned layer_end, bool training, bool pred_leaf, bool pred_contribs, bool approx_contribs, bool pred_interactions) override { int multiple_predictions = static_cast(pred_leaf) + @@ -1085,16 +1086,16 @@ class LearnerImpl : public LearnerIO { this->Configure(); CHECK_LE(multiple_predictions, 1) << "Perform one kind of prediction at a time."; if (pred_contribs) { - gbm_->PredictContribution(data.get(), out_preds, ntree_limit, approx_contribs); + gbm_->PredictContribution(data.get(), out_preds, layer_begin, layer_end, approx_contribs); } else if (pred_interactions) { - gbm_->PredictInteractionContributions(data.get(), out_preds, ntree_limit, + gbm_->PredictInteractionContributions(data.get(), out_preds, layer_begin, layer_end, approx_contribs); } else if (pred_leaf) { - gbm_->PredictLeaf(data.get(), out_preds, ntree_limit); + gbm_->PredictLeaf(data.get(), out_preds, layer_begin, layer_end); } else { auto local_cache = this->GetPredictionCache(); auto& prediction = local_cache->Cache(data, generic_parameters_.gpu_id); - this->PredictRaw(data.get(), &prediction, training, ntree_limit); + this->PredictRaw(data.get(), &prediction, training, layer_begin, layer_end); // Copy the prediction cache to output prediction. out_preds comes from C API out_preds->SetDevice(generic_parameters_.gpu_id); out_preds->Resize(prediction.predictions.Size()); @@ -1151,12 +1152,11 @@ class LearnerImpl : public LearnerIO { * predictor, when it equals 0, this means we are using all the trees * \param training allow dropout when the DART booster is being used */ - void PredictRaw(DMatrix* data, PredictionCacheEntry* out_preds, - bool training, - unsigned ntree_limit = 0) const { + void PredictRaw(DMatrix *data, PredictionCacheEntry *out_preds, bool training, + unsigned layer_begin, unsigned layer_end) const { CHECK(gbm_ != nullptr) << "Predict must happen after Load or configuration"; this->ValidateDMatrix(data, false); - gbm_->PredictBatch(data, out_preds, training, ntree_limit); + gbm_->PredictBatch(data, out_preds, training, layer_begin, layer_end); } void ValidateDMatrix(DMatrix* p_fmat, bool is_training) const { diff --git a/src/predictor/cpu_predictor.cc b/src/predictor/cpu_predictor.cc index 9fdf925db..338f24afc 100644 --- a/src/predictor/cpu_predictor.cc +++ b/src/predictor/cpu_predictor.cc @@ -234,56 +234,28 @@ class CPUPredictor : public Predictor { public: explicit CPUPredictor(GenericParameter const* generic_param) : Predictor::Predictor{generic_param} {} - // ntree_limit is a very problematic parameter, as it's ambiguous in the context of - // multi-output and forest. Same problem exists for tree_begin - void PredictBatch(DMatrix* dmat, PredictionCacheEntry* predts, - const gbm::GBTreeModel& model, int tree_begin, - uint32_t const ntree_limit = 0) const override { - // tree_begin is not used, right now we just enforce it to be 0. - CHECK_EQ(tree_begin, 0); + void PredictBatch(DMatrix *dmat, PredictionCacheEntry *predts, + const gbm::GBTreeModel &model, uint32_t tree_begin, + uint32_t tree_end = 0) const override { auto* out_preds = &predts->predictions; - CHECK_GE(predts->version, tree_begin); if (out_preds->Size() == 0 && dmat->Info().num_row_ != 0) { CHECK_EQ(predts->version, 0); } + // This is actually already handled in gbm, but large amount of tests rely on the + // behaviour. + if (tree_end == 0) { + tree_end = model.trees.size(); + } if (predts->version == 0) { // out_preds->Size() can be non-zero as it's initialized here before any tree is // built at the 0^th iterator. this->InitOutPredictions(dmat->Info(), out_preds, model); } - - uint32_t const output_groups = model.learner_model_param->num_output_group; - CHECK_NE(output_groups, 0); - // Right now we just assume ntree_limit provided by users means number of tree layers - // in the context of multi-output model - uint32_t real_ntree_limit = ntree_limit * output_groups; - if (real_ntree_limit == 0 || real_ntree_limit > model.trees.size()) { - real_ntree_limit = static_cast(model.trees.size()); + if (tree_end - tree_begin == 0) { + return; } - - uint32_t const end_version = (tree_begin + real_ntree_limit) / output_groups; - // When users have provided ntree_limit, end_version can be lesser, cache is violated - if (predts->version > end_version) { - CHECK_NE(ntree_limit, 0); - this->InitOutPredictions(dmat->Info(), out_preds, model); - predts->version = 0; - } - uint32_t const beg_version = predts->version; - CHECK_LE(beg_version, end_version); - - if (beg_version < end_version) { - this->PredictDMatrix(dmat, &out_preds->HostVector(), model, - beg_version * output_groups, - end_version * output_groups); - } - - // delta means {size of forest} * {number of newly accumulated layers} - uint32_t delta = end_version - beg_version; - CHECK_LE(delta, model.trees.size()); - predts->Update(delta); - - CHECK(out_preds->Size() == output_groups * dmat->Info().num_row_ || - out_preds->Size() == dmat->Info().num_row_); + this->PredictDMatrix(dmat, &out_preds->HostVector(), model, tree_begin, + tree_end); } template @@ -362,7 +334,6 @@ class CPUPredictor : public Predictor { InitThreadTemp(nthread, model.learner_model_param->num_feature, &feat_vecs); const MetaInfo& info = p_fmat->Info(); // number of valid trees - ntree_limit *= model.learner_model_param->num_output_group; if (ntree_limit == 0 || ntree_limit > model.trees.size()) { ntree_limit = static_cast(model.trees.size()); } @@ -398,7 +369,6 @@ class CPUPredictor : public Predictor { InitThreadTemp(nthread, model.learner_model_param->num_feature, &feat_vecs); const MetaInfo& info = p_fmat->Info(); // number of valid trees - ntree_limit *= model.learner_model_param->num_output_group; if (ntree_limit == 0 || ntree_limit > model.trees.size()) { ntree_limit = static_cast(model.trees.size()); } diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu index d786229c9..03b9e1652 100644 --- a/src/predictor/gpu_predictor.cu +++ b/src/predictor/gpu_predictor.cu @@ -536,6 +536,7 @@ class GPUPredictor : public xgboost::Predictor { const uint32_t BLOCK_THREADS = 256; size_t num_rows = batch.n_rows; auto GRID_SIZE = static_cast(common::DivRoundUp(num_rows, BLOCK_THREADS)); + DeviceModel d_model; bool use_shared = false; size_t entry_start = 0; @@ -593,54 +594,27 @@ class GPUPredictor : public xgboost::Predictor { } void PredictBatch(DMatrix* dmat, PredictionCacheEntry* predts, - const gbm::GBTreeModel& model, int tree_begin, - unsigned ntree_limit = 0) const override { - // This function is duplicated with CPU predictor PredictBatch, see comments in there. - // FIXME(trivialfis): Remove the duplication. + const gbm::GBTreeModel& model, uint32_t tree_begin, + uint32_t tree_end = 0) const override { int device = generic_param_->gpu_id; CHECK_GE(device, 0) << "Set `gpu_id' to positive value for processing GPU data."; - ConfigureDevice(device); - - CHECK_EQ(tree_begin, 0); auto* out_preds = &predts->predictions; - CHECK_GE(predts->version, tree_begin); + if (out_preds->Size() == 0 && dmat->Info().num_row_ != 0) { CHECK_EQ(predts->version, 0); } + if (tree_end == 0) { + tree_end = model.trees.size(); + } if (predts->version == 0) { + // out_preds->Size() can be non-zero as it's initialized here before any tree is + // built at the 0^th iterator. this->InitOutPredictions(dmat->Info(), out_preds, model); } - - uint32_t const output_groups = model.learner_model_param->num_output_group; - CHECK_NE(output_groups, 0); - - uint32_t real_ntree_limit = ntree_limit * output_groups; - if (real_ntree_limit == 0 || real_ntree_limit > model.trees.size()) { - real_ntree_limit = static_cast(model.trees.size()); + if (tree_end - tree_begin == 0) { + return; } - - uint32_t const end_version = (tree_begin + real_ntree_limit) / output_groups; - - if (predts->version > end_version) { - CHECK_NE(ntree_limit, 0); - this->InitOutPredictions(dmat->Info(), out_preds, model); - predts->version = 0; - } - uint32_t const beg_version = predts->version; - CHECK_LE(beg_version, end_version); - - if (beg_version < end_version) { - this->DevicePredictInternal(dmat, out_preds, model, - beg_version * output_groups, - end_version * output_groups); - } - - uint32_t delta = end_version - beg_version; - CHECK_LE(delta, model.trees.size()); - predts->Update(delta); - - CHECK(out_preds->Size() == output_groups * dmat->Info().num_row_ || - out_preds->Size() == dmat->Info().num_row_); + this->DevicePredictInternal(dmat, out_preds, model, tree_begin, tree_end); } template @@ -648,15 +622,12 @@ class GPUPredictor : public xgboost::Predictor { const gbm::GBTreeModel &model, float, PredictionCacheEntry *out_preds, uint32_t tree_begin, uint32_t tree_end) const { - auto max_shared_memory_bytes = dh::MaxSharedMemory(this->generic_param_->gpu_id); uint32_t const output_groups = model.learner_model_param->num_output_group; - DeviceModel d_model; - d_model.Init(model, tree_begin, tree_end, this->generic_param_->gpu_id); auto m = dmlc::get>(x); CHECK_EQ(m->NumColumns(), model.learner_model_param->num_feature) << "Number of columns in data must equal to trained model."; - CHECK_EQ(this->generic_param_->gpu_id, m->DeviceIdx()) + CHECK_EQ(dh::CurrentDevice(), m->DeviceIdx()) << "XGBoost is running on device: " << this->generic_param_->gpu_id << ", " << "but data is on: " << m->DeviceIdx(); if (p_m) { @@ -667,12 +638,17 @@ class GPUPredictor : public xgboost::Predictor { info.num_row_ = m->NumRows(); this->InitOutPredictions(info, &(out_preds->predictions), model); } + out_preds->predictions.SetDevice(m->DeviceIdx()); const uint32_t BLOCK_THREADS = 128; auto GRID_SIZE = static_cast(common::DivRoundUp(m->NumRows(), BLOCK_THREADS)); + auto max_shared_memory_bytes = dh::MaxSharedMemory(m->DeviceIdx()); size_t shared_memory_bytes = SharedMemoryBytes(m->NumColumns(), max_shared_memory_bytes); + DeviceModel d_model; + d_model.Init(model, tree_begin, tree_end, m->DeviceIdx()); + bool use_shared = shared_memory_bytes != 0; size_t entry_start = 0; @@ -707,20 +683,17 @@ class GPUPredictor : public xgboost::Predictor { void PredictContribution(DMatrix* p_fmat, HostDeviceVector* out_contribs, - const gbm::GBTreeModel& model, unsigned ntree_limit, + const gbm::GBTreeModel& model, unsigned tree_end, std::vector*, bool approximate, int, unsigned) const override { if (approximate) { LOG(FATAL) << "Approximated contribution is not implemented in GPU Predictor."; } - dh::safe_cuda(cudaSetDevice(generic_param_->gpu_id)); out_contribs->SetDevice(generic_param_->gpu_id); - uint32_t real_ntree_limit = - ntree_limit * model.learner_model_param->num_output_group; - if (real_ntree_limit == 0 || real_ntree_limit > model.trees.size()) { - real_ntree_limit = static_cast(model.trees.size()); + if (tree_end == 0 || tree_end > model.trees.size()) { + tree_end = static_cast(model.trees.size()); } const int ngroup = model.learner_model_param->num_output_group; @@ -734,8 +707,7 @@ class GPUPredictor : public xgboost::Predictor { auto phis = out_contribs->DeviceSpan(); dh::device_vector device_paths; - ExtractPaths(&device_paths, model, real_ntree_limit, - generic_param_->gpu_id); + ExtractPaths(&device_paths, model, tree_end, generic_param_->gpu_id); for (auto& batch : p_fmat->GetBatches()) { batch.data.SetDevice(generic_param_->gpu_id); batch.offset.SetDevice(generic_param_->gpu_id); @@ -761,20 +733,17 @@ class GPUPredictor : public xgboost::Predictor { void PredictInteractionContributions(DMatrix* p_fmat, HostDeviceVector* out_contribs, const gbm::GBTreeModel& model, - unsigned ntree_limit, + unsigned tree_end, std::vector*, bool approximate) const override { if (approximate) { LOG(FATAL) << "[Internal error]: " << __func__ << " approximate is not implemented in GPU Predictor."; } - dh::safe_cuda(cudaSetDevice(generic_param_->gpu_id)); out_contribs->SetDevice(generic_param_->gpu_id); - uint32_t real_ntree_limit = - ntree_limit * model.learner_model_param->num_output_group; - if (real_ntree_limit == 0 || real_ntree_limit > model.trees.size()) { - real_ntree_limit = static_cast(model.trees.size()); + if (tree_end == 0 || tree_end > model.trees.size()) { + tree_end = static_cast(model.trees.size()); } const int ngroup = model.learner_model_param->num_output_group; @@ -789,8 +758,7 @@ class GPUPredictor : public xgboost::Predictor { auto phis = out_contribs->DeviceSpan(); dh::device_vector device_paths; - ExtractPaths(&device_paths, model, real_ntree_limit, - generic_param_->gpu_id); + ExtractPaths(&device_paths, model, tree_end, generic_param_->gpu_id); for (auto& batch : p_fmat->GetBatches()) { batch.data.SetDevice(generic_param_->gpu_id); batch.offset.SetDevice(generic_param_->gpu_id); @@ -841,29 +809,28 @@ class GPUPredictor : public xgboost::Predictor { << " is not implemented in GPU Predictor."; } - void PredictLeaf(DMatrix* p_fmat, HostDeviceVector* predictions, - const gbm::GBTreeModel& model, - unsigned ntree_limit) const override { + void PredictLeaf(DMatrix *p_fmat, HostDeviceVector *predictions, + const gbm::GBTreeModel &model, + unsigned tree_end) const override { dh::safe_cuda(cudaSetDevice(generic_param_->gpu_id)); auto max_shared_memory_bytes = ConfigureDevice(generic_param_->gpu_id); const MetaInfo& info = p_fmat->Info(); constexpr uint32_t kBlockThreads = 128; - size_t shared_memory_bytes = - SharedMemoryBytes(info.num_col_, max_shared_memory_bytes); + size_t shared_memory_bytes = SharedMemoryBytes( + info.num_col_, max_shared_memory_bytes); bool use_shared = shared_memory_bytes != 0; bst_feature_t num_features = info.num_col_; bst_row_t num_rows = info.num_row_; size_t entry_start = 0; - uint32_t real_ntree_limit = ntree_limit * model.learner_model_param->num_output_group; - if (real_ntree_limit == 0 || real_ntree_limit > model.trees.size()) { - real_ntree_limit = static_cast(model.trees.size()); + if (tree_end == 0 || tree_end > model.trees.size()) { + tree_end = static_cast(model.trees.size()); } predictions->SetDevice(generic_param_->gpu_id); - predictions->Resize(num_rows * real_ntree_limit); + predictions->Resize(num_rows * tree_end); DeviceModel d_model; - d_model.Init(model, 0, real_ntree_limit, this->generic_param_->gpu_id); + d_model.Init(model, 0, tree_end, this->generic_param_->gpu_id); if (p_fmat->PageExists()) { for (auto const& batch : p_fmat->GetBatches()) { diff --git a/tests/ci_build/conda_env/cpu_test.yml b/tests/ci_build/conda_env/cpu_test.yml index 6158515ba..691b46644 100644 --- a/tests/ci_build/conda_env/cpu_test.yml +++ b/tests/ci_build/conda_env/cpu_test.yml @@ -34,6 +34,7 @@ dependencies: - llvmlite - pip: - shap + - ipython # required by shap at import time. - guzzle_sphinx_theme - datatable - modin[all] diff --git a/tests/cpp/gbm/test_gbtree.cc b/tests/cpp/gbm/test_gbtree.cc index 64a94e736..2fbbab27f 100644 --- a/tests/cpp/gbm/test_gbtree.cc +++ b/tests/cpp/gbm/test_gbtree.cc @@ -51,6 +51,53 @@ TEST(GBTree, SelectTreeMethod) { #endif // XGBOOST_USE_CUDA } +TEST(GBTree, PredictionCache) { + size_t constexpr kRows = 100, kCols = 10; + GenericParameter generic_param; + generic_param.UpdateAllowUnknown(Args{}); + LearnerModelParam mparam; + mparam.base_score = 0.5; + mparam.num_feature = kCols; + mparam.num_output_group = 1; + + std::unique_ptr p_gbm { + GradientBooster::Create("gbtree", &generic_param, &mparam)}; + auto& gbtree = dynamic_cast (*p_gbm); + + gbtree.Configure({{"tree_method", "hist"}}); + auto p_m = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(); + auto gpair = GenerateRandomGradients(kRows); + PredictionCacheEntry out_predictions; + gbtree.DoBoost(p_m.get(), &gpair, &out_predictions); + + gbtree.PredictBatch(p_m.get(), &out_predictions, false, 0, 0); + ASSERT_EQ(1, out_predictions.version); + std::vector first_iter = out_predictions.predictions.HostVector(); + // Add 1 more boosted round + gbtree.DoBoost(p_m.get(), &gpair, &out_predictions); + gbtree.PredictBatch(p_m.get(), &out_predictions, false, 0, 0); + ASSERT_EQ(2, out_predictions.version); + // Update the cache for all rounds + out_predictions.version = 0; + gbtree.PredictBatch(p_m.get(), &out_predictions, false, 0, 0); + ASSERT_EQ(2, out_predictions.version); + + gbtree.DoBoost(p_m.get(), &gpair, &out_predictions); + // drop the cache. + gbtree.PredictBatch(p_m.get(), &out_predictions, false, 1, 2); + ASSERT_EQ(0, out_predictions.version); + // half open set [1, 3) + gbtree.PredictBatch(p_m.get(), &out_predictions, false, 1, 3); + ASSERT_EQ(0, out_predictions.version); + // iteration end + gbtree.PredictBatch(p_m.get(), &out_predictions, false, 0, 2); + ASSERT_EQ(2, out_predictions.version); + // restart the cache when end iteration is smaller than cache version + gbtree.PredictBatch(p_m.get(), &out_predictions, false, 0, 1); + ASSERT_EQ(1, out_predictions.version); + ASSERT_EQ(out_predictions.predictions.HostVector(), first_iter); +} + TEST(GBTree, WrongUpdater) { size_t constexpr kRows = 17; size_t constexpr kCols = 15; diff --git a/tests/cpp/predictor/test_cpu_predictor.cc b/tests/cpp/predictor/test_cpu_predictor.cc index 634747991..c5ee0b2e2 100644 --- a/tests/cpp/predictor/test_cpu_predictor.cc +++ b/tests/cpp/predictor/test_cpu_predictor.cc @@ -32,7 +32,7 @@ TEST(CpuPredictor, Basic) { // Test predict batch PredictionCacheEntry out_predictions; cpu_predictor->PredictBatch(dmat.get(), &out_predictions, model, 0); - ASSERT_EQ(model.trees.size(), out_predictions.version); + std::vector& out_predictions_h = out_predictions.predictions.HostVector(); for (size_t i = 0; i < out_predictions.predictions.Size(); i++) { ASSERT_EQ(out_predictions_h[i], 1.5); @@ -215,7 +215,7 @@ TEST(CpuPredictor, UpdatePredictionCache) { PredictionCacheEntry out_predictions; // perform fair prediction on the same input data, should be equal to cached result - gbm->PredictBatch(dmat.get(), &out_predictions, false, 0); + gbm->PredictBatch(dmat.get(), &out_predictions, false, 0, 0); std::vector &out_predictions_h = out_predictions.predictions.HostVector(); std::vector &predtion_cache_from_train = predtion_cache.predictions.HostVector(); diff --git a/tests/cpp/predictor/test_gpu_predictor.cu b/tests/cpp/predictor/test_gpu_predictor.cu index 4d4417c66..6d38aec29 100644 --- a/tests/cpp/predictor/test_gpu_predictor.cu +++ b/tests/cpp/predictor/test_gpu_predictor.cu @@ -45,7 +45,6 @@ TEST(GPUPredictor, Basic) { PredictionCacheEntry cpu_out_predictions; gpu_predictor->PredictBatch(dmat.get(), &gpu_out_predictions, model, 0); - ASSERT_EQ(model.trees.size(), gpu_out_predictions.version); cpu_predictor->PredictBatch(dmat.get(), &cpu_out_predictions, model, 0); std::vector& gpu_out_predictions_h = gpu_out_predictions.predictions.HostVector(); diff --git a/tests/cpp/predictor/test_predictor.cc b/tests/cpp/predictor/test_predictor.cc index 7be88fea0..8df9d72d2 100644 --- a/tests/cpp/predictor/test_predictor.cc +++ b/tests/cpp/predictor/test_predictor.cc @@ -64,10 +64,10 @@ void TestTrainingPrediction(size_t rows, size_t bins, } HostDeviceVector from_full; - learner->Predict(p_full, false, &from_full); + learner->Predict(p_full, false, &from_full, 0, 0); HostDeviceVector from_hist; - learner->Predict(p_hist, false, &from_hist); + learner->Predict(p_hist, false, &from_hist, 0, 0); for (size_t i = 0; i < rows; ++i) { EXPECT_NEAR(from_hist.ConstHostVector()[i], @@ -157,20 +157,20 @@ void TestPredictionWithLesserFeatures(std::string predictor_name) { learner->SaveConfig(&config); ASSERT_EQ(get(config["learner"]["gradient_booster"]["gbtree_train_param"]["predictor"]), predictor_name); - learner->Predict(m_test, false, &prediction); + learner->Predict(m_test, false, &prediction, 0, 0); ASSERT_EQ(prediction.Size(), kRows); auto m_invalid = RandomDataGenerator(kRows, kTrainCols + 1, 0.5).GenerateDMatrix(false); - ASSERT_THROW({learner->Predict(m_invalid, false, &prediction);}, dmlc::Error); + ASSERT_THROW({learner->Predict(m_invalid, false, &prediction, 0, 0);}, dmlc::Error); #if defined(XGBOOST_USE_CUDA) HostDeviceVector from_cpu; learner->SetParam("predictor", "cpu_predictor"); - learner->Predict(m_test, false, &from_cpu); + learner->Predict(m_test, false, &from_cpu, 0, 0); HostDeviceVector from_cuda; learner->SetParam("predictor", "gpu_predictor"); - learner->Predict(m_test, false, &from_cuda); + learner->Predict(m_test, false, &from_cuda, 0, 0); auto const& h_cpu = from_cpu.ConstHostVector(); auto const& h_gpu = from_cuda.ConstHostVector(); diff --git a/tests/cpp/test_learner.cc b/tests/cpp/test_learner.cc index 727610569..19a78645d 100644 --- a/tests/cpp/test_learner.cc +++ b/tests/cpp/test_learner.cc @@ -221,9 +221,10 @@ TEST(Learner, MultiThreadedPredict) { auto &entry = learner->GetThreadLocal().prediction_entry; HostDeviceVector predictions; for (size_t iter = 0; iter < kIters; ++iter) { - learner->Predict(p_data, false, &entry.predictions); - learner->Predict(p_data, false, &predictions, 0, true); // leaf - learner->Predict(p_data, false, &predictions, 0, false, true); // contribs + learner->Predict(p_data, false, &entry.predictions, 0, 0); + + learner->Predict(p_data, false, &predictions, 0, 0, false, true); // leaf + learner->Predict(p_data, false, &predictions, 0, 0, false, false, true); // contribs } }); } diff --git a/tests/python-gpu/test_from_cupy.py b/tests/python-gpu/test_from_cupy.py index 6db412ce3..60b73e675 100644 --- a/tests/python-gpu/test_from_cupy.py +++ b/tests/python-gpu/test_from_cupy.py @@ -112,17 +112,24 @@ def _test_cupy_metainfo(DMatrixT): @pytest.mark.skipif(**tm.no_sklearn()) def test_cupy_training_with_sklearn(): import cupy as cp + np.random.seed(1) cp.random.seed(1) - X = cp.random.randn(50, 10, dtype='float32') - y = (cp.random.randn(50, dtype='float32') > 0).astype('int8') + X = cp.random.randn(50, 10, dtype="float32") + y = (cp.random.randn(50, dtype="float32") > 0).astype("int8") weights = np.random.random(50) + 1 cupy_weights = cp.array(weights) base_margin = np.random.random(50) cupy_base_margin = cp.array(base_margin) - clf = xgb.XGBClassifier(gpu_id=0, tree_method='gpu_hist', use_label_encoder=False) - clf.fit(X, y, sample_weight=cupy_weights, base_margin=cupy_base_margin, eval_set=[(X, y)]) + clf = xgb.XGBClassifier(gpu_id=0, tree_method="gpu_hist", use_label_encoder=False) + clf.fit( + X, + y, + sample_weight=cupy_weights, + base_margin=cupy_base_margin, + eval_set=[(X, y)], + ) pred = clf.predict(X) assert np.array_equal(np.unique(pred), np.array([0, 1])) diff --git a/tests/python-gpu/test_gpu_with_dask.py b/tests/python-gpu/test_gpu_with_dask.py index 51ef9c1e3..880e99083 100644 --- a/tests/python-gpu/test_gpu_with_dask.py +++ b/tests/python-gpu/test_gpu_with_dask.py @@ -16,13 +16,15 @@ if sys.platform.startswith("win"): pytest.skip("Skipping dask tests on Windows", allow_module_level=True) sys.path.append("tests/python") -from test_with_dask import run_empty_dmatrix_reg # noqa -from test_with_dask import run_empty_dmatrix_cls # noqa -from test_with_dask import _get_client_workers # noqa -from test_with_dask import generate_array # noqa -from test_with_dask import kCols as random_cols # noqa -from test_with_dask import suppress # noqa -import testing as tm # noqa +from test_with_dask import run_empty_dmatrix_reg # noqa +from test_with_dask import run_boost_from_prediction # noqa +from test_with_dask import run_dask_classifier # noqa +from test_with_dask import run_empty_dmatrix_cls # noqa +from test_with_dask import _get_client_workers # noqa +from test_with_dask import generate_array # noqa +from test_with_dask import kCols as random_cols # noqa +from test_with_dask import suppress # noqa +import testing as tm # noqa try: @@ -132,9 +134,9 @@ def run_gpu_hist( num_rounds: int, dataset: tm.TestDataset, DMatrixT: Type, - client: Client + client: Client, ) -> None: - params['tree_method'] = 'gpu_hist' + params["tree_method"] = "gpu_hist" params = dataset.set_params(params) # It doesn't make sense to distribute a completely # empty dataset. @@ -143,26 +145,40 @@ def run_gpu_hist( chunk = 128 X = to_cp(dataset.X, DMatrixT) - X = da.from_array(X, - chunks=(chunk, dataset.X.shape[1])) + X = da.from_array(X, chunks=(chunk, dataset.X.shape[1])) y = to_cp(dataset.y, DMatrixT) - y = da.from_array(y, chunks=(chunk, )) + y = da.from_array(y, chunks=(chunk,)) if dataset.w is not None: w = to_cp(dataset.w, DMatrixT) - w = da.from_array(w, chunks=(chunk, )) + w = da.from_array(w, chunks=(chunk,)) else: w = None if DMatrixT is dxgb.DaskDeviceQuantileDMatrix: - m = DMatrixT(client, data=X, label=y, weight=w, - max_bin=params.get('max_bin', 256)) + m = DMatrixT( + client, data=X, label=y, weight=w, max_bin=params.get("max_bin", 256) + ) else: m = DMatrixT(client, data=X, label=y, weight=w) - history = dxgb.train(client, params=params, dtrain=m, - num_boost_round=num_rounds, - evals=[(m, 'train')])['history'] + history = dxgb.train( + client, + params=params, + dtrain=m, + num_boost_round=num_rounds, + evals=[(m, "train")], + )["history"] note(history) - assert tm.non_increasing(history['train'][dataset.metric]) + assert tm.non_increasing(history["train"][dataset.metric]) + + +def test_boost_from_prediction(local_cuda_cluster: LocalCUDACluster) -> None: + import cudf + from sklearn.datasets import load_breast_cancer + with Client(local_cuda_cluster) as client: + X_, y_ = load_breast_cancer(return_X_y=True) + X = dd.from_array(X_, chunksize=100).map_partitions(cudf.from_pandas) + y = dd.from_array(y_, chunksize=100).map_partitions(cudf.from_pandas) + run_boost_from_prediction(X, y, "gpu_hist", client) class TestDistributedGPU: @@ -246,6 +262,20 @@ class TestDistributedGPU: dump = booster.get_dump(dump_format='json') assert len(dump) - booster.best_iteration == early_stopping_rounds + 1 + @pytest.mark.skipif(**tm.no_cudf()) + @pytest.mark.skipif(**tm.no_dask()) + @pytest.mark.skipif(**tm.no_dask_cuda()) + @pytest.mark.parametrize("model", ["boosting"]) + def test_dask_classifier(self, model, local_cuda_cluster: LocalCUDACluster) -> None: + import dask_cudf + with Client(local_cuda_cluster) as client: + X_, y_, w_ = generate_array(with_weights=True) + y_ = (y_ * 10).astype(np.int32) + X = dask_cudf.from_dask_dataframe(dd.from_dask_array(X_)) + y = dask_cudf.from_dask_dataframe(dd.from_dask_array(y_)) + w = dask_cudf.from_dask_dataframe(dd.from_dask_array(w_)) + run_dask_classifier(X, y, w, model, client) + @pytest.mark.skipif(**tm.no_dask()) @pytest.mark.skipif(**tm.no_dask_cuda()) @pytest.mark.mgpu diff --git a/tests/python/test_basic_models.py b/tests/python/test_basic_models.py index d1cfa33e5..beef2f331 100644 --- a/tests/python/test_basic_models.py +++ b/tests/python/test_basic_models.py @@ -434,7 +434,13 @@ class TestModels: booster[...:end] = booster sliced_0 = booster[1:3] + np.testing.assert_allclose( + booster.predict(dtrain, iteration_range=(1, 3)), sliced_0.predict(dtrain) + ) sliced_1 = booster[3:7] + np.testing.assert_allclose( + booster.predict(dtrain, iteration_range=(3, 7)), sliced_1.predict(dtrain) + ) predt_0 = sliced_0.predict(dtrain, output_margin=True) predt_1 = sliced_1.predict(dtrain, output_margin=True) diff --git a/tests/python/test_predict.py b/tests/python/test_predict.py index 2dfd40005..174a4a13e 100644 --- a/tests/python/test_predict.py +++ b/tests/python/test_predict.py @@ -47,30 +47,27 @@ def run_predict_leaf(predictor): empty_leaf = booster.predict(empty, pred_leaf=True) assert empty_leaf.shape[0] == 0 - leaf = booster.predict(m, pred_leaf=True) + leaf = booster.predict(m, pred_leaf=True, strict_shape=True) assert leaf.shape[0] == rows - assert leaf.shape[1] == classes * num_parallel_tree * num_boost_round + assert leaf.shape[1] == num_boost_round + assert leaf.shape[2] == classes + assert leaf.shape[3] == num_parallel_tree for i in range(rows): - row = leaf[i, ...] for j in range(num_boost_round): - start = classes * num_parallel_tree * j - end = classes * num_parallel_tree * (j + 1) - layer = row[start: end] - for c in range(classes): - tree_group = layer[c * num_parallel_tree: (c + 1) * num_parallel_tree] + for k in range(classes): + tree_group = leaf[i, j, k, :] assert tree_group.shape[0] == num_parallel_tree - # no subsampling so tree in same forest should output same - # leaf. + # No sampling, all trees within forest are the same assert np.all(tree_group == tree_group[0]) ntree_limit = 2 sliced = booster.predict( - m, pred_leaf=True, ntree_limit=num_parallel_tree * ntree_limit + m, pred_leaf=True, ntree_limit=num_parallel_tree * ntree_limit, strict_shape=True ) first = sliced[0, ...] - assert first.shape[0] == classes * num_parallel_tree * ntree_limit + assert np.prod(first.shape) == classes * num_parallel_tree * ntree_limit return leaf @@ -78,6 +75,23 @@ def test_predict_leaf(): run_predict_leaf('cpu_predictor') +def test_predict_shape(): + from sklearn.datasets import load_boston + X, y = load_boston(return_X_y=True) + reg = xgb.XGBRegressor(n_estimators=1) + reg.fit(X, y) + predt = reg.get_booster().predict(xgb.DMatrix(X), strict_shape=True) + assert len(predt.shape) == 2 + assert predt.shape[0] == X.shape[0] + assert predt.shape[1] == 1 + + contrib = reg.get_booster().predict( + xgb.DMatrix(X), pred_contribs=True, strict_shape=True + ) + assert len(contrib.shape) == 3 + assert contrib.shape[1] == 1 + + class TestInplacePredict: '''Tests for running inplace prediction''' @classmethod @@ -92,8 +106,7 @@ class TestInplacePredict: dtrain = xgb.DMatrix(cls.X, cls.y) - cls.booster = xgb.train({'tree_method': 'hist'}, - dtrain, num_boost_round=10) + cls.booster = xgb.train({'tree_method': 'hist'}, dtrain, num_boost_round=10) cls.test = xgb.DMatrix(cls.X[:10, ...]) diff --git a/tests/python/test_with_dask.py b/tests/python/test_with_dask.py index 5a66a5ec1..6205a3275 100644 --- a/tests/python/test_with_dask.py +++ b/tests/python/test_with_dask.py @@ -159,12 +159,9 @@ def test_dask_predict_shape_infer(client: "Client") -> None: assert prediction.shape[1] == 3 -@pytest.mark.parametrize("tree_method", ["hist", "approx"]) -def test_boost_from_prediction(tree_method: str, client: "Client") -> None: - from sklearn.datasets import load_breast_cancer - X_, y_ = load_breast_cancer(return_X_y=True) - - X, y = dd.from_array(X_, chunksize=100), dd.from_array(y_, chunksize=100) +def run_boost_from_prediction( + X: xgb.dask._DaskCollection, y: xgb.dask._DaskCollection, tree_method: str, client: "Client" +) -> None: model_0 = xgb.dask.DaskXGBClassifier( learning_rate=0.3, random_state=0, n_estimators=4, tree_method=tree_method) @@ -202,6 +199,30 @@ def test_boost_from_prediction(tree_method: str, client: "Client") -> None: assert margined_res[i] < unmargined_res[i] +@pytest.mark.parametrize("tree_method", ["hist", "approx"]) +def test_boost_from_prediction(tree_method: str, client: "Client") -> None: + from sklearn.datasets import load_breast_cancer + X_, y_ = load_breast_cancer(return_X_y=True) + X, y = dd.from_array(X_, chunksize=100), dd.from_array(y_, chunksize=100) + run_boost_from_prediction(X, y, tree_method, client) + + +def test_inplace_predict(client: "Client") -> None: + from sklearn.datasets import load_boston + X_, y_ = load_boston(return_X_y=True) + X, y = dd.from_array(X_, chunksize=32), dd.from_array(y_, chunksize=32) + reg = xgb.dask.DaskXGBRegressor(n_estimators=4).fit(X, y) + booster = reg.get_booster() + base_margin = y + + inplace = xgb.dask.inplace_predict( + client, booster, X, base_margin=base_margin + ).compute() + Xy = xgb.dask.DaskDMatrix(client, X, base_margin=base_margin) + copied = xgb.dask.predict(client, booster, Xy).compute() + np.testing.assert_allclose(inplace, copied) + + def test_dask_missing_value_reg(client: "Client") -> None: X_0 = np.ones((20 // 2, kCols)) X_1 = np.zeros((20 // 2, kCols)) @@ -288,10 +309,13 @@ def test_dask_regressor(model: str, client: "Client") -> None: assert forest == 2 -@pytest.mark.parametrize("model", ["boosting", "rf"]) -def test_dask_classifier(model: str, client: "Client") -> None: - X, y, w = generate_array(with_weights=True) - y = (y * 10).astype(np.int32) +def run_dask_classifier( + X: xgb.dask._DaskCollection, + y: xgb.dask._DaskCollection, + w: xgb.dask._DaskCollection, + model: str, + client: "Client", +) -> None: if model == "boosting": classifier = xgb.dask.DaskXGBClassifier( verbosity=1, n_estimators=2, eval_metric="merror" @@ -306,14 +330,13 @@ def test_dask_classifier(model: str, client: "Client") -> None: classifier.client = client classifier.fit(X, y, sample_weight=w, eval_set=[(X, y)]) - prediction = classifier.predict(X) + prediction = classifier.predict(X).compute() assert prediction.ndim == 1 assert prediction.shape[0] == kRows history = classifier.evals_result() - assert isinstance(prediction, da.Array) assert isinstance(history, dict) assert list(history.keys())[0] == "validation_0" @@ -332,7 +355,7 @@ def test_dask_classifier(model: str, client: "Client") -> None: assert forest == 2 # Test .predict_proba() - probas = classifier.predict_proba(X) + probas = classifier.predict_proba(X).compute() assert classifier.n_classes_ == 10 assert probas.ndim == 2 assert probas.shape[0] == kRows @@ -341,18 +364,33 @@ def test_dask_classifier(model: str, client: "Client") -> None: cls_booster = classifier.get_booster() single_node_proba = cls_booster.inplace_predict(X.compute()) - np.testing.assert_allclose(single_node_proba, probas.compute()) + # test shared by CPU and GPU + if isinstance(single_node_proba, np.ndarray): + np.testing.assert_allclose(single_node_proba, probas) + else: + import cupy + cupy.testing.assert_allclose(single_node_proba, probas) - # Test with dataframe. - X_d = dd.from_dask_array(X) - y_d = dd.from_dask_array(y) - classifier.fit(X_d, y_d) + # Test with dataframe, not shared with GPU as cupy doesn't work well with da.unique. + if isinstance(X, da.Array): + X_d: dd.DataFrame = X.to_dask_dataframe() - assert classifier.n_classes_ == 10 - prediction = classifier.predict(X_d).compute() + assert classifier.n_classes_ == 10 + prediction_df = classifier.predict(X_d).compute() - assert prediction.ndim == 1 - assert prediction.shape[0] == kRows + assert prediction_df.ndim == 1 + assert prediction_df.shape[0] == kRows + np.testing.assert_allclose(prediction_df, prediction) + + probas = classifier.predict_proba(X).compute() + np.testing.assert_allclose(single_node_proba, probas) + + +@pytest.mark.parametrize("model", ["boosting", "rf"]) +def test_dask_classifier(model: str, client: "Client") -> None: + X, y, w = generate_array(with_weights=True) + y = (y * 10).astype(np.int32) + run_dask_classifier(X, y, w, model, client) @pytest.mark.skipif(**tm.no_sklearn()) @@ -913,9 +951,9 @@ class TestWithDask: train = xgb.dask.DaskDMatrix(client, dX, dy) dX = dd.from_array(X) - dX = client.persist(dX, workers={dX: workers[1]}) + dX = client.persist(dX, workers=workers[1]) dy = dd.from_array(y) - dy = client.persist(dy, workers={dy: workers[1]}) + dy = client.persist(dy, workers=workers[1]) valid = xgb.dask.DaskDMatrix(client, dX, dy) merged = xgb.dask._get_workers_from_data(train, evals=[(valid, 'Valid')]) @@ -1060,6 +1098,16 @@ class TestWithDask: assert_shape(shap.shape) assert np.allclose(np.sum(shap, axis=len(shap.shape) - 1), margin, 1e-5, 1e-5) + X = dd.from_dask_array(X).repartition(npartitions=32) + y = dd.from_dask_array(y).repartition(npartitions=32) + shap_df = xgb.dask.predict( + client, booster, X, pred_contribs=True, validate_features=False + ).compute() + assert_shape(shap_df.shape) + assert np.allclose( + np.sum(shap_df, axis=len(shap_df.shape) - 1), margin, 1e-5, 1e-5 + ) + def run_shap_cls_sklearn(self, X: Any, y: Any, client: "Client") -> None: X, y = da.from_array(X, chunks=(32, -1)), da.from_array(y, chunks=32) cls = xgb.dask.DaskXGBClassifier(n_estimators=4)