[breaking] Add prediction fucntion for DMatrix and use inplace predict for dask. (#6668)

* Add a new API function for predicting on `DMatrix`.  This function aligns
with rest of the `XGBoosterPredictFrom*` functions on semantic of function
arguments.
* Purge `ntree_limit` from libxgboost, use iteration instead.
* [dask] Use `inplace_predict` by default for dask sklearn models.
* [dask] Run prediction shape inference on worker instead of client.

The breaking change is in the Python sklearn `apply` function, I made it to be
consistent with other prediction functions where `best_iteration` is used by
default.
This commit is contained in:
Jiaming Yuan 2021-02-08 18:26:32 +08:00 committed by GitHub
parent dbb5208a0a
commit 4656b09d5d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
29 changed files with 1134 additions and 604 deletions

View File

@ -704,8 +704,9 @@ XGB_DLL int XGBoosterEvalOneIter(BoosterHandle handle,
const char *evnames[], const char *evnames[],
bst_ulong len, bst_ulong len,
const char **out_result); const char **out_result);
/*! /*!
* \brief make prediction based on dmat * \brief make prediction based on dmat (deprecated, use `XGBoosterPredictFromDMatrix` instead)
* \param handle handle * \param handle handle
* \param dmat data matrix * \param dmat data matrix
* \param option_mask bit-mask of options taken in prediction, possible values * \param option_mask bit-mask of options taken in prediction, possible values
@ -734,6 +735,165 @@ XGB_DLL int XGBoosterPredict(BoosterHandle handle,
int training, int training,
bst_ulong *out_len, bst_ulong *out_len,
const float **out_result); const float **out_result);
/*!
* \brief Make prediction from DMatrix, replacing `XGBoosterPredict`.
*
* \param handle Booster handle
* \param dmat DMatrix handle
* \param c_json_config String encoded predict configuration in JSON format.
*
* "type": [0, 5]
* 0: normal prediction
* 1: output margin
* 2: predict contribution
* 3: predict approxmated contribution
* 4: predict feature interaction
* 5: predict leaf
* "training": bool
* Whether the prediction function is used as part of a training loop. **Not used
* for inplace prediction**.
*
* Prediction can be run in 2 scenarios:
* 1. Given data matrix X, obtain prediction y_pred from the model.
* 2. Obtain the prediction for computing gradients. For example, DART booster performs dropout
* during training, and the prediction result will be different from the one obtained by normal
* inference step due to dropped trees.
* Set training=false for the first scenario. Set training=true for the second
* scenario. The second scenario applies when you are defining a custom objective
* function.
* "iteration_begin": int
* Beginning iteration of prediction.
* "iteration_end": int
* End iteration of prediction. Set to 0 this will become the size of tree model.
* "strict_shape": bool
* Whether should we reshape the output with stricter rules. If set to true,
* normal/margin/contrib/interaction predict will output consistent shape
* disregarding the use of multi-class model, and leaf prediction will output 4-dim
* array representing: (n_samples, n_iterations, n_classes, n_trees_in_forest)
*
* Run a normal prediction with strict output shape, 2 dim for softprob , 1 dim for others.
* \code
* {
* "type": 0,
* "training": False,
* "iteration_begin": 0,
* "iteration_end": 0,
* "strict_shape": true,
* }
* \endcode
*
* \param out_shape Shape of output prediction (copy before use).
* \param out_dim Dimension of output prediction.
* \param out_result Buffer storing prediction value (copy before use).
*
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGBoosterPredictFromDMatrix(BoosterHandle handle,
DMatrixHandle dmat,
char const* c_json_config,
bst_ulong const **out_shape,
bst_ulong *out_dim,
float const **out_result);
/*
* \brief Inplace prediction from CPU dense matrix.
*
* \param handle Booster handle.
* \param values JSON encoded __array_interface__ to values.
* \param c_json_config See `XGBoosterPredictFromDMatrix` for more info.
*
* Additional fields for inplace prediction are:
* "missing": float
*
* \param m An optional (NULL if not available) proxy DMatrix instance
* storing meta info.
*
* \param out_shape See `XGBoosterPredictFromDMatrix` for more info.
* \param out_dim See `XGBoosterPredictFromDMatrix` for more info.
* \param out_result See `XGBoosterPredictFromDMatrix` for more info.
*
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGBoosterPredictFromDense(BoosterHandle handle,
char const *values,
char const *c_json_config,
DMatrixHandle m,
bst_ulong const **out_shape,
bst_ulong *out_dim,
const float **out_result);
/*
* \brief Inplace prediction from CPU CSR matrix.
*
* \param handle Booster handle.
* \param indptr JSON encoded __array_interface__ to row pointer in CSR.
* \param indices JSON encoded __array_interface__ to column indices in CSR.
* \param values JSON encoded __array_interface__ to values in CSR..
* \param ncol Number of features in data.
* \param c_json_config See `XGBoosterPredictFromDMatrix` for more info.
* Additional fields for inplace prediction are:
* "missing": float
*
* \param m An optional (NULL if not available) proxy DMatrix instance
* storing meta info.
*
* \param out_shape See `XGBoosterPredictFromDMatrix` for more info.
* \param out_dim See `XGBoosterPredictFromDMatrix` for more info.
* \param out_result See `XGBoosterPredictFromDMatrix` for more info.
*
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGBoosterPredictFromCSR(BoosterHandle handle, char const *indptr,
char const *indices, char const *values,
bst_ulong ncol,
char const *c_json_config, DMatrixHandle m,
bst_ulong const **out_shape,
bst_ulong *out_dim,
const float **out_result);
/*
* \brief Inplace prediction from CUDA Dense matrix (cupy in Python).
*
* \param handle Booster handle
* \param values JSON encoded __cuda_array_interface__ to values.
* \param c_json_config See `XGBoosterPredictFromDMatrix` for more info.
* Additional fields for inplace prediction are:
* "missing": float
*
* \param m An optional (NULL if not available) proxy DMatrix instance
* storing meta info.
* \param out_shape See `XGBoosterPredictFromDMatrix` for more info.
* \param out_dim See `XGBoosterPredictFromDMatrix` for more info.
* \param out_result See `XGBoosterPredictFromDMatrix` for more info.
*
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGBoosterPredictFromCudaArray(
BoosterHandle handle, char const *values, char const *c_json_config,
DMatrixHandle m, bst_ulong const **out_shape, bst_ulong *out_dim,
const float **out_result);
/*
* \brief Inplace prediction from CUDA dense dataframe (cuDF in Python).
*
* \param handle Booster handle
* \param values List of __cuda_array_interface__ for all columns encoded in JSON list.
* \param c_json_config See `XGBoosterPredictFromDMatrix` for more info.
* Additional fields for inplace prediction are:
* "missing": float
*
* \param m An optional (NULL if not available) proxy DMatrix instance
* storing meta info.
* \param out_shape See `XGBoosterPredictFromDMatrix` for more info.
* \param out_dim See `XGBoosterPredictFromDMatrix` for more info.
* \param out_result See `XGBoosterPredictFromDMatrix` for more info.
*
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGBoosterPredictFromCudaColumnar(
BoosterHandle handle, char const *values, char const *c_json_config,
DMatrixHandle m, bst_ulong const **out_shape, bst_ulong *out_dim,
const float **out_result);
/* /*
* ========================== Begin Serialization APIs ========================= * ========================== Begin Serialization APIs =========================

View File

@ -63,7 +63,7 @@ class GradientBooster : public Model, public Configurable {
/*! /*!
* \brief Slice a model using boosting index. The slice m:n indicates taking all trees * \brief Slice a model using boosting index. The slice m:n indicates taking all trees
* that were fit during the boosting rounds m, (m+1), (m+2), ..., (n-1). * that were fit during the boosting rounds m, (m+1), (m+2), ..., (n-1).
* \param layer_begin Begining of boosted tree layer used for prediction. * \param layer_begin Beginning of boosted tree layer used for prediction.
* \param layer_end End of booster layer. 0 means do not limit trees. * \param layer_end End of booster layer. 0 means do not limit trees.
* \param out Output gradient booster * \param out Output gradient booster
*/ */
@ -99,15 +99,14 @@ class GradientBooster : public Model, public Configurable {
* \param out_preds output vector to hold the predictions * \param out_preds output vector to hold the predictions
* \param training Whether the prediction value is used for training. For dart booster * \param training Whether the prediction value is used for training. For dart booster
* drop out is performed during training. * drop out is performed during training.
* \param ntree_limit limit the number of trees used in prediction, * \param layer_begin Beginning of boosted tree layer used for prediction.
* when it equals 0, this means we do not limit * \param layer_end End of booster layer. 0 means do not limit trees.
* number of trees, this parameter is only valid
* for gbtree, but not for gblinear
*/ */
virtual void PredictBatch(DMatrix* dmat, virtual void PredictBatch(DMatrix* dmat,
PredictionCacheEntry* out_preds, PredictionCacheEntry* out_preds,
bool training, bool training,
unsigned ntree_limit = 0) = 0; unsigned layer_begin,
unsigned layer_end) = 0;
/*! /*!
* \brief Inplace prediction. * \brief Inplace prediction.
@ -115,7 +114,7 @@ class GradientBooster : public Model, public Configurable {
* \param x A type erased data adapter. * \param x A type erased data adapter.
* \param missing Missing value in the data. * \param missing Missing value in the data.
* \param [in,out] out_preds The output preds. * \param [in,out] out_preds The output preds.
* \param layer_begin (Optional) Begining of boosted tree layer used for prediction. * \param layer_begin (Optional) Beginning of boosted tree layer used for prediction.
* \param layer_end (Optional) End of booster layer. 0 means do not limit trees. * \param layer_end (Optional) End of booster layer. 0 means do not limit trees.
*/ */
virtual void InplacePredict(dmlc::any const &, std::shared_ptr<DMatrix>, float, virtual void InplacePredict(dmlc::any const &, std::shared_ptr<DMatrix>, float,
@ -132,44 +131,45 @@ class GradientBooster : public Model, public Configurable {
* *
* \param inst the instance you want to predict * \param inst the instance you want to predict
* \param out_preds output vector to hold the predictions * \param out_preds output vector to hold the predictions
* \param ntree_limit limit the number of trees used in prediction * \param layer_begin Beginning of boosted tree layer used for prediction.
* \param layer_end End of booster layer. 0 means do not limit trees.
* \sa Predict * \sa Predict
*/ */
virtual void PredictInstance(const SparsePage::Inst& inst, virtual void PredictInstance(const SparsePage::Inst& inst,
std::vector<bst_float>* out_preds, std::vector<bst_float>* out_preds,
unsigned ntree_limit = 0) = 0; unsigned layer_begin, unsigned layer_end) = 0;
/*! /*!
* \brief predict the leaf index of each tree, the output will be nsample * ntree vector * \brief predict the leaf index of each tree, the output will be nsample * ntree vector
* this is only valid in gbtree predictor * this is only valid in gbtree predictor
* \param dmat feature matrix * \param dmat feature matrix
* \param out_preds output vector to hold the predictions * \param out_preds output vector to hold the predictions
* \param ntree_limit limit the number of trees used in prediction, when it equals 0, this means * \param layer_begin Beginning of boosted tree layer used for prediction.
* we do not limit number of trees, this parameter is only valid for gbtree, but not for gblinear * \param layer_end End of booster layer. 0 means do not limit trees.
*/ */
virtual void PredictLeaf(DMatrix* dmat, virtual void PredictLeaf(DMatrix *dmat,
HostDeviceVector<bst_float>* out_preds, HostDeviceVector<bst_float> *out_preds,
unsigned ntree_limit = 0) = 0; unsigned layer_begin, unsigned layer_end) = 0;
/*! /*!
* \brief feature contributions to individual predictions; the output will be a vector * \brief feature contributions to individual predictions; the output will be a vector
* of length (nfeats + 1) * num_output_group * nsample, arranged in that order * of length (nfeats + 1) * num_output_group * nsample, arranged in that order
* \param dmat feature matrix * \param dmat feature matrix
* \param out_contribs output vector to hold the contributions * \param out_contribs output vector to hold the contributions
* \param ntree_limit limit the number of trees used in prediction, when it equals 0, this means * \param layer_begin Beginning of boosted tree layer used for prediction.
* we do not limit number of trees * \param layer_end End of booster layer. 0 means do not limit trees.
* \param approximate use a faster (inconsistent) approximation of SHAP values * \param approximate use a faster (inconsistent) approximation of SHAP values
* \param condition condition on the condition_feature (0=no, -1=cond off, 1=cond on). * \param condition condition on the condition_feature (0=no, -1=cond off, 1=cond on).
* \param condition_feature feature to condition on (i.e. fix) during calculations * \param condition_feature feature to condition on (i.e. fix) during calculations
*/ */
virtual void PredictContribution(DMatrix* dmat, virtual void PredictContribution(DMatrix* dmat,
HostDeviceVector<bst_float>* out_contribs, HostDeviceVector<bst_float>* out_contribs,
unsigned ntree_limit = 0, unsigned layer_begin, unsigned layer_end,
bool approximate = false, int condition = 0, bool approximate = false, int condition = 0,
unsigned condition_feature = 0) = 0; unsigned condition_feature = 0) = 0;
virtual void PredictInteractionContributions(DMatrix* dmat, virtual void PredictInteractionContributions(
HostDeviceVector<bst_float>* out_contribs, DMatrix *dmat, HostDeviceVector<bst_float> *out_contribs,
unsigned ntree_limit, bool approximate) = 0; unsigned layer_begin, unsigned layer_end, bool approximate) = 0;
/*! /*!
* \brief dump the model in the requested format * \brief dump the model in the requested format

View File

@ -113,8 +113,8 @@ class Learner : public Model, public Configurable, public dmlc::Serializable {
* \param data input data * \param data input data
* \param output_margin whether to only predict margin value instead of transformed prediction * \param output_margin whether to only predict margin value instead of transformed prediction
* \param out_preds output vector that stores the prediction * \param out_preds output vector that stores the prediction
* \param ntree_limit limit number of trees used for boosted tree * \param layer_begin Beginning of boosted tree layer used for prediction.
* predictor, when it equals 0, this means we are using all the trees * \param layer_end End of booster layer. 0 means do not limit trees.
* \param training Whether the prediction result is used for training * \param training Whether the prediction result is used for training
* \param pred_leaf whether to only predict the leaf index of each tree in a boosted tree predictor * \param pred_leaf whether to only predict the leaf index of each tree in a boosted tree predictor
* \param pred_contribs whether to only predict the feature contributions * \param pred_contribs whether to only predict the feature contributions
@ -124,7 +124,8 @@ class Learner : public Model, public Configurable, public dmlc::Serializable {
virtual void Predict(std::shared_ptr<DMatrix> data, virtual void Predict(std::shared_ptr<DMatrix> data,
bool output_margin, bool output_margin,
HostDeviceVector<bst_float> *out_preds, HostDeviceVector<bst_float> *out_preds,
unsigned ntree_limit = 0, unsigned layer_begin,
unsigned layer_end,
bool training = false, bool training = false,
bool pred_leaf = false, bool pred_leaf = false,
bool pred_contribs = false, bool pred_contribs = false,
@ -140,7 +141,7 @@ class Learner : public Model, public Configurable, public dmlc::Serializable {
* \param type Prediction type. * \param type Prediction type.
* \param missing Missing value in the data. * \param missing Missing value in the data.
* \param [in,out] out_preds Pointer to output prediction vector. * \param [in,out] out_preds Pointer to output prediction vector.
* \param layer_begin Begining of boosted tree layer used for prediction. * \param layer_begin Beginning of boosted tree layer used for prediction.
* \param layer_end End of booster layer. 0 means do not limit trees. * \param layer_end End of booster layer. 0 means do not limit trees.
*/ */
virtual void InplacePredict(dmlc::any const &x, virtual void InplacePredict(dmlc::any const &x,

View File

@ -127,12 +127,11 @@ class Predictor {
* \param [in,out] out_preds The output preds. * \param [in,out] out_preds The output preds.
* \param model The model to predict from. * \param model The model to predict from.
* \param tree_begin The tree begin index. * \param tree_begin The tree begin index.
* \param ntree_limit (Optional) The ntree limit. 0 means do not * \param tree_end The tree end index.
* limit trees.
*/ */
virtual void PredictBatch(DMatrix* dmat, PredictionCacheEntry* out_preds, virtual void PredictBatch(DMatrix* dmat, PredictionCacheEntry* out_preds,
const gbm::GBTreeModel& model, int tree_begin, const gbm::GBTreeModel& model, uint32_t tree_begin,
uint32_t const ntree_limit = 0) const = 0; uint32_t tree_end = 0) const = 0;
/** /**
* \brief Inplace prediction. * \brief Inplace prediction.
@ -140,7 +139,7 @@ class Predictor {
* \param model The model to predict from. * \param model The model to predict from.
* \param missing Missing value in the data. * \param missing Missing value in the data.
* \param [in,out] out_preds The output preds. * \param [in,out] out_preds The output preds.
* \param tree_begin (Optional) Begining of boosted trees used for prediction. * \param tree_begin (Optional) Beginning of boosted trees used for prediction.
* \param tree_end (Optional) End of booster trees. 0 means do not limit trees. * \param tree_end (Optional) End of booster trees. 0 means do not limit trees.
* *
* \return True if the data can be handled by current predictor, false otherwise. * \return True if the data can be handled by current predictor, false otherwise.
@ -159,13 +158,13 @@ class Predictor {
* \param inst The instance to predict. * \param inst The instance to predict.
* \param [in,out] out_preds The output preds. * \param [in,out] out_preds The output preds.
* \param model The model to predict from * \param model The model to predict from
* \param ntree_limit (Optional) The ntree limit. * \param tree_end (Optional) The tree end index.
*/ */
virtual void PredictInstance(const SparsePage::Inst& inst, virtual void PredictInstance(const SparsePage::Inst& inst,
std::vector<bst_float>* out_preds, std::vector<bst_float>* out_preds,
const gbm::GBTreeModel& model, const gbm::GBTreeModel& model,
unsigned ntree_limit = 0) const = 0; unsigned tree_end = 0) const = 0;
/** /**
* \brief predict the leaf index of each tree, the output will be nsample * * \brief predict the leaf index of each tree, the output will be nsample *
@ -174,18 +173,14 @@ class Predictor {
* \param [in,out] dmat The input feature matrix. * \param [in,out] dmat The input feature matrix.
* \param [in,out] out_preds The output preds. * \param [in,out] out_preds The output preds.
* \param model Model to make predictions from. * \param model Model to make predictions from.
* \param ntree_limit (Optional) The ntree limit. * \param tree_end (Optional) The tree end index.
*/ */
virtual void PredictLeaf(DMatrix* dmat, HostDeviceVector<bst_float>* out_preds, virtual void PredictLeaf(DMatrix* dmat, HostDeviceVector<bst_float>* out_preds,
const gbm::GBTreeModel& model, const gbm::GBTreeModel& model,
unsigned ntree_limit = 0) const = 0; unsigned tree_end = 0) const = 0;
/** /**
* \fn virtual void Predictor::PredictContribution( DMatrix* dmat,
* std::vector<bst_float>* out_contribs, const gbm::GBTreeModel& model,
* unsigned ntree_limit = 0) = 0;
*
* \brief feature contributions to individual predictions; the output will be * \brief feature contributions to individual predictions; the output will be
* a vector of length (nfeats + 1) * num_output_group * nsample, arranged in * a vector of length (nfeats + 1) * num_output_group * nsample, arranged in
* that order. * that order.
@ -193,7 +188,7 @@ class Predictor {
* \param [in,out] dmat The input feature matrix. * \param [in,out] dmat The input feature matrix.
* \param [in,out] out_contribs The output feature contribs. * \param [in,out] out_contribs The output feature contribs.
* \param model Model to make predictions from. * \param model Model to make predictions from.
* \param ntree_limit (Optional) The ntree limit. * \param tree_end The tree end index.
* \param tree_weights (Optional) Weights to multiply each tree by. * \param tree_weights (Optional) Weights to multiply each tree by.
* \param approximate Use fast approximate algorithm. * \param approximate Use fast approximate algorithm.
* \param condition Condition on the condition_feature (0=no, -1=cond off, 1=cond on). * \param condition Condition on the condition_feature (0=no, -1=cond off, 1=cond on).
@ -203,7 +198,7 @@ class Predictor {
virtual void PredictContribution(DMatrix* dmat, virtual void PredictContribution(DMatrix* dmat,
HostDeviceVector<bst_float>* out_contribs, HostDeviceVector<bst_float>* out_contribs,
const gbm::GBTreeModel& model, const gbm::GBTreeModel& model,
unsigned ntree_limit = 0, unsigned tree_end = 0,
std::vector<bst_float>* tree_weights = nullptr, std::vector<bst_float>* tree_weights = nullptr,
bool approximate = false, bool approximate = false,
int condition = 0, int condition = 0,
@ -212,7 +207,7 @@ class Predictor {
virtual void PredictInteractionContributions(DMatrix* dmat, virtual void PredictInteractionContributions(DMatrix* dmat,
HostDeviceVector<bst_float>* out_contribs, HostDeviceVector<bst_float>* out_contribs,
const gbm::GBTreeModel& model, const gbm::GBTreeModel& model,
unsigned ntree_limit = 0, unsigned tree_end = 0,
std::vector<bst_float>* tree_weights = nullptr, std::vector<bst_float>* tree_weights = nullptr,
bool approximate = false) const = 0; bool approximate = false) const = 0;

View File

@ -96,6 +96,24 @@ def from_cstr_to_pystr(data, length):
return res return res
def _convert_ntree_limit(booster, ntree_limit, iteration_range):
if ntree_limit is not None and ntree_limit != 0:
warnings.warn(
"ntree_limit is deprecated, use `iteration_range` or model "
"slicing instead.",
UserWarning
)
if iteration_range is not None and iteration_range[1] != 0:
raise ValueError(
"Only one of `iteration_range` and `ntree_limit` can be non zero."
)
num_parallel_tree, num_groups = _get_booster_layer_trees(booster)
num_parallel_tree = max([num_parallel_tree, 1])
num_groups = max([num_groups, 1])
iteration_range = (0, ntree_limit // num_parallel_tree)
return iteration_range
def _expect(expectations, got): def _expect(expectations, got):
"""Translate input error into string. """Translate input error into string.
@ -1111,6 +1129,34 @@ Objective = Callable[[np.ndarray, DMatrix], Tuple[np.ndarray, np.ndarray]]
Metric = Callable[[np.ndarray, DMatrix], Tuple[str, float]] Metric = Callable[[np.ndarray, DMatrix], Tuple[str, float]]
def _get_booster_layer_trees(model: "Booster") -> Tuple[int, int]:
"""Get number of trees added to booster per-iteration. This function will be removed
once `best_ntree_limit` is dropped in favor of `best_iteration`. Returns
`num_parallel_tree` and `num_groups`.
"""
config = json.loads(model.save_config())
booster = config["learner"]["gradient_booster"]["name"]
if booster == "gblinear":
num_parallel_tree = 0
elif booster == "dart":
num_parallel_tree = int(
config["learner"]["gradient_booster"]["gbtree"]["gbtree_train_param"][
"num_parallel_tree"
]
)
elif booster == "gbtree":
num_parallel_tree = int(
config["learner"]["gradient_booster"]["gbtree_train_param"][
"num_parallel_tree"
]
)
else:
raise ValueError(f"Unknown booster: {booster}")
num_groups = int(config["learner"]["learner_model_param"]["num_class"])
return num_parallel_tree, num_groups
class Booster(object): class Booster(object):
# pylint: disable=too-many-public-methods # pylint: disable=too-many-public-methods
"""A Booster of XGBoost. """A Booster of XGBoost.
@ -1497,16 +1543,20 @@ class Booster(object):
return self.eval_set([(data, name)], iteration) return self.eval_set([(data, name)], iteration)
# pylint: disable=too-many-function-args # pylint: disable=too-many-function-args
def predict(self, def predict(
data, self,
output_margin=False, data: DMatrix,
ntree_limit=0, output_margin: bool = False,
pred_leaf=False, ntree_limit: int = 0,
pred_contribs=False, pred_leaf: bool = False,
approx_contribs=False, pred_contribs: bool = False,
pred_interactions=False, approx_contribs: bool = False,
validate_features=True, pred_interactions: bool = False,
training=False): validate_features: bool = True,
training: bool = False,
iteration_range: Tuple[int, int] = (0, 0),
strict_shape: bool = False,
) -> np.ndarray:
"""Predict with data. """Predict with data.
.. note:: This function is not thread safe except for ``gbtree`` booster. .. note:: This function is not thread safe except for ``gbtree`` booster.
@ -1518,33 +1568,32 @@ class Booster(object):
Parameters Parameters
---------- ----------
data : DMatrix data :
The dmatrix storing the input. The dmatrix storing the input.
output_margin : bool output_margin :
Whether to output the raw untransformed margin value. Whether to output the raw untransformed margin value.
ntree_limit : int ntree_limit :
Limit number of trees in the prediction; defaults to 0 (use all Deprecated, use `iteration_range` instead.
trees).
pred_leaf : bool pred_leaf :
When this option is on, the output will be a matrix of (nsample, When this option is on, the output will be a matrix of (nsample,
ntrees) with each record indicating the predicted leaf index of ntrees) with each record indicating the predicted leaf index of
each sample in each tree. Note that the leaf index of a tree is each sample in each tree. Note that the leaf index of a tree is
unique per tree, so you may find leaf 1 in both tree 1 and tree 0. unique per tree, so you may find leaf 1 in both tree 1 and tree 0.
pred_contribs : bool pred_contribs :
When this is True the output will be a matrix of size (nsample, When this is True the output will be a matrix of size (nsample,
nfeats + 1) with each record indicating the feature contributions nfeats + 1) with each record indicating the feature contributions
(SHAP values) for that prediction. The sum of all feature (SHAP values) for that prediction. The sum of all feature
contributions is equal to the raw untransformed margin value of the contributions is equal to the raw untransformed margin value of the
prediction. Note the final column is the bias term. prediction. Note the final column is the bias term.
approx_contribs : bool approx_contribs :
Approximate the contributions of each feature Approximate the contributions of each feature
pred_interactions : bool pred_interactions :
When this is True the output will be a matrix of size (nsample, When this is True the output will be a matrix of size (nsample,
nfeats + 1, nfeats + 1) indicating the SHAP interaction values for nfeats + 1, nfeats + 1) indicating the SHAP interaction values for
each pair of features. The sum of each row (or column) of the each pair of features. The sum of each row (or column) of the
@ -1553,17 +1602,33 @@ class Booster(object):
untransformed margin value of the prediction. Note the last row and untransformed margin value of the prediction. Note the last row and
column correspond to the bias term. column correspond to the bias term.
validate_features : bool validate_features :
When this is True, validate that the Booster's and data's When this is True, validate that the Booster's and data's
feature_names are identical. Otherwise, it is assumed that the feature_names are identical. Otherwise, it is assumed that the
feature_names are the same. feature_names are the same.
training : bool training :
Whether the prediction value is used for training. This can effect Whether the prediction value is used for training. This can effect
`dart` booster, which performs dropouts during training iterations. `dart` booster, which performs dropouts during training iterations.
.. versionadded:: 1.0.0 .. versionadded:: 1.0.0
iteration_range :
Specifies which layer of trees are used in prediction. For example, if a
random forest is trained with 100 rounds. Specifying `iteration_range=(10,
20)`, then only the forests built during [10, 20) (half open set) rounds are
used in this prediction.
.. versionadded:: 1.4.0
strict_shape :
When set to True, output shape is invariant to whether classification is used.
For both value and margin prediction, the output shape is (n_samples,
n_groups), n_groups == 1 when multi-class is not used. Default to False, in
which case the output shape can be (n_samples, ) if multi-class is not used.
.. versionadded:: 1.4.0
.. note:: Using ``predict()`` with DART booster .. note:: Using ``predict()`` with DART booster
If the booster object is DART type, ``predict()`` will not perform If the booster object is DART type, ``predict()`` will not perform
@ -1575,64 +1640,50 @@ class Booster(object):
prediction : numpy array prediction : numpy array
""" """
option_mask = 0x00
if output_margin:
option_mask |= 0x01
if pred_leaf:
option_mask |= 0x02
if pred_contribs:
option_mask |= 0x04
if approx_contribs:
option_mask |= 0x08
if pred_interactions:
option_mask |= 0x10
if not isinstance(data, DMatrix): if not isinstance(data, DMatrix):
raise TypeError('Expecting data to be a DMatrix object, got: ', raise TypeError('Expecting data to be a DMatrix object, got: ', type(data))
type(data))
if validate_features: if validate_features:
self._validate_features(data) self._validate_features(data)
iteration_range = _convert_ntree_limit(self, ntree_limit, iteration_range)
args = {
"type": 0,
"training": training,
"iteration_begin": iteration_range[0],
"iteration_end": iteration_range[1],
"strict_shape": strict_shape,
}
length = c_bst_ulong() def assign_type(t: int) -> None:
preds = ctypes.POINTER(ctypes.c_float)() if args["type"] != 0:
_check_call(_LIB.XGBoosterPredict(self.handle, data.handle, raise ValueError("One type of prediction at a time.")
ctypes.c_int(option_mask), args["type"] = t
ctypes.c_uint(ntree_limit),
ctypes.c_int(training),
ctypes.byref(length),
ctypes.byref(preds)))
preds = ctypes2numpy(preds, length.value, np.float32)
if pred_leaf:
preds = preds.astype(np.int32, copy=False)
nrow = data.num_row()
if preds.size != nrow and preds.size % nrow == 0:
chunk_size = int(preds.size / nrow)
if output_margin:
assign_type(1)
if pred_contribs:
assign_type(2 if not approx_contribs else 3)
if pred_interactions: if pred_interactions:
ngroup = int(chunk_size / ((data.num_col() + 1) * assign_type(4)
(data.num_col() + 1))) if pred_leaf:
if ngroup == 1: assign_type(5)
preds = preds.reshape(nrow, preds = ctypes.POINTER(ctypes.c_float)()
data.num_col() + 1, shape = ctypes.POINTER(c_bst_ulong)()
data.num_col() + 1) dims = c_bst_ulong()
else: _check_call(
preds = preds.reshape(nrow, ngroup, _LIB.XGBoosterPredictFromDMatrix(
data.num_col() + 1, self.handle,
data.num_col() + 1) data.handle,
elif pred_contribs: from_pystr_to_cstr(json.dumps(args)),
ngroup = int(chunk_size / (data.num_col() + 1)) ctypes.byref(shape),
if ngroup == 1: ctypes.byref(dims),
preds = preds.reshape(nrow, data.num_col() + 1) ctypes.byref(preds)
else: )
preds = preds.reshape(nrow, ngroup, data.num_col() + 1) )
else: return _prediction_output(shape, dims, preds, False)
preds = preds.reshape(nrow, chunk_size)
return preds
def inplace_predict( def inplace_predict(
self, self,
data, data: Any,
iteration_range: Tuple[int, int] = (0, 0), iteration_range: Tuple[int, int] = (0, 0),
predict_type: str = "value", predict_type: str = "value",
missing: float = np.nan, missing: float = np.nan,
@ -1665,26 +1716,24 @@ class Booster(object):
The input data, must not be a view for numpy array. Set The input data, must not be a view for numpy array. Set
``predictor`` to ``gpu_predictor`` for running prediction on CuPy ``predictor`` to ``gpu_predictor`` for running prediction on CuPy
array or CuDF DataFrame. array or CuDF DataFrame.
iteration_range : tuple iteration_range :
Specifies which layer of trees are used in prediction. For See :py:meth:`xgboost.Booster.predict` for details.
example, if a random forest is trained with 100 rounds. Specifying predict_type :
`iteration_range=(10, 20)`, then only the forests built during [10,
20) (open set) rounds are used in this prediction.
predict_type : str
* `value` Output model prediction values. * `value` Output model prediction values.
* `margin` Output the raw untransformed margin value. * `margin` Output the raw untransformed margin value.
missing : float missing :
Value in the input data which needs to be present as a missing See :py:obj:`xgboost.DMatrix` for details.
value.
validate_features: validate_features:
See :py:meth:`xgboost.Booster.predict` for details. See :py:meth:`xgboost.Booster.predict` for details.
base_margin: base_margin:
See :py:obj:`xgboost.DMatrix` for details. See :py:obj:`xgboost.DMatrix` for details.
.. versionadded:: 1.4.0
strict_shape: strict_shape:
When set to True, output shape is invariant to whether classification is used. See :py:meth:`xgboost.Booster.predict` for details.
For both value and margin prediction, the output shape is (n_samples,
n_groups), n_groups == 1 when multi-class is not used. Default to False, in .. versionadded:: 1.4.0
which case the output shape can be (n_samples, ) if multi-class is not used.
Returns Returns
------- -------
@ -1772,7 +1821,7 @@ class Booster(object):
interface["mask"] = interface["mask"].__cuda_array_interface__ interface["mask"] = interface["mask"].__cuda_array_interface__
interface_str = bytes(json.dumps(interface, indent=2), "utf-8") interface_str = bytes(json.dumps(interface, indent=2), "utf-8")
_check_call( _check_call(
_LIB.XGBoosterPredictFromArrayInterface( _LIB.XGBoosterPredictFromCudaArray(
self.handle, self.handle,
interface_str, interface_str,
from_pystr_to_cstr(json.dumps(args)), from_pystr_to_cstr(json.dumps(args)),
@ -1788,7 +1837,7 @@ class Booster(object):
interfaces_str = _cudf_array_interfaces(data) interfaces_str = _cudf_array_interfaces(data)
_check_call( _check_call(
_LIB.XGBoosterPredictFromArrayInterfaceColumns( _LIB.XGBoosterPredictFromCudaColumnar(
self.handle, self.handle,
interfaces_str, interfaces_str,
from_pystr_to_cstr(json.dumps(args)), from_pystr_to_cstr(json.dumps(args)),

View File

@ -254,6 +254,7 @@ class DaskDMatrix:
raise TypeError(_expect((dd.DataFrame, da.Array, dd.Series), type(label))) raise TypeError(_expect((dd.DataFrame, da.Array, dd.Series), type(label)))
self._n_cols = data.shape[1] self._n_cols = data.shape[1]
assert isinstance(self._n_cols, int)
self.worker_map: Dict[str, "distributed.Future"] = defaultdict(list) self.worker_map: Dict[str, "distributed.Future"] = defaultdict(list)
self.is_quantile: bool = False self.is_quantile: bool = False
@ -881,7 +882,7 @@ async def _train_async(
return list(filter(lambda ret: ret is not None, results))[0] return list(filter(lambda ret: ret is not None, results))[0]
def train( def train( # pylint: disable=unused-argument
client: "distributed.Client", client: "distributed.Client",
params: Dict[str, Any], params: Dict[str, Any],
dtrain: DaskDMatrix, dtrain: DaskDMatrix,
@ -892,16 +893,17 @@ def train(
early_stopping_rounds: Optional[int] = None, early_stopping_rounds: Optional[int] = None,
xgb_model: Optional[Booster] = None, xgb_model: Optional[Booster] = None,
verbose_eval: Union[int, bool] = True, verbose_eval: Union[int, bool] = True,
callbacks: Optional[List[TrainingCallback]] = None callbacks: Optional[List[TrainingCallback]] = None,
) -> Any: ) -> Any:
'''Train XGBoost model. """Train XGBoost model.
.. versionadded:: 1.0.0 .. versionadded:: 1.0.0
.. note:: .. note::
Other parameters are the same as `xgboost.train` except for `evals_result`, which Other parameters are the same as :py:func:`xgboost.train` except for
is returned as part of function return value instead of argument. `evals_result`, which is returned as part of function return value instead of
argument.
Parameters Parameters
---------- ----------
@ -920,29 +922,17 @@ def train(
{'booster': xgboost.Booster, {'booster': xgboost.Booster,
'history': {'train': {'logloss': ['0.48253', '0.35953']}, 'history': {'train': {'logloss': ['0.48253', '0.35953']},
'eval': {'logloss': ['0.480385', '0.357756']}}} 'eval': {'logloss': ['0.480385', '0.357756']}}}
'''
"""
_assert_dask_support() _assert_dask_support()
client = _xgb_get_client(client) client = _xgb_get_client(client)
# Get global configuration before transferring computation to another thread or # Get global configuration before transferring computation to another thread or
# process. # process.
global_config = config.get_config() return client.sync(_train_async, global_config=config.get_config(), **locals())
return client.sync(_train_async,
client=client,
global_config=global_config,
num_boost_round=num_boost_round,
obj=obj,
feval=feval,
params=params,
dtrain=dtrain,
evals=evals,
early_stopping_rounds=early_stopping_rounds,
verbose_eval=verbose_eval,
xgb_model=xgb_model,
callbacks=callbacks)
def _can_output_df(data: _DaskCollection, output_shape: Tuple) -> bool: def _can_output_df(is_df: bool, output_shape: Tuple) -> bool:
return isinstance(data, dd.DataFrame) and len(output_shape) <= 2 return is_df and len(output_shape) <= 2
async def _direct_predict_impl( async def _direct_predict_impl(
@ -954,8 +944,9 @@ async def _direct_predict_impl(
meta: Dict[int, str], meta: Dict[int, str],
) -> _DaskCollection: ) -> _DaskCollection:
columns = list(meta.keys()) columns = list(meta.keys())
if _can_output_df(data, output_shape): if _can_output_df(isinstance(data, dd.DataFrame), output_shape):
if base_margin is not None and isinstance(base_margin, da.Array): if base_margin is not None and isinstance(base_margin, da.Array):
# Easier for map_partitions
base_margin_df: Optional[dd.DataFrame] = base_margin.to_dask_dataframe() base_margin_df: Optional[dd.DataFrame] = base_margin.to_dask_dataframe()
else: else:
base_margin_df = base_margin base_margin_df = base_margin
@ -975,16 +966,20 @@ async def _direct_predict_impl(
if base_margin is not None and isinstance( if base_margin is not None and isinstance(
base_margin, (dd.Series, dd.DataFrame) base_margin, (dd.Series, dd.DataFrame)
): ):
# Easier for map_blocks
base_margin_array: Optional[da.Array] = base_margin.to_dask_array() base_margin_array: Optional[da.Array] = base_margin.to_dask_array()
else: else:
base_margin_array = base_margin base_margin_array = base_margin
# Input data is 2-dim array, output can be 1(reg, binary)/2(multi-class, # Input data is 2-dim array, output can be 1(reg, binary)/2(multi-class,
# contrib)/3(contrib)/4(interaction) dims. # contrib)/3(contrib, interaction)/4(interaction) dims.
if len(output_shape) == 1: if len(output_shape) == 1:
drop_axis: Union[int, List[int]] = [1] # drop from 2 to 1 dim. drop_axis: Union[int, List[int]] = [1] # drop from 2 to 1 dim.
new_axis: Union[int, List[int]] = [] new_axis: Union[int, List[int]] = []
else: else:
drop_axis = [] drop_axis = []
if isinstance(data, dd.DataFrame):
new_axis = list(range(len(output_shape) - 2))
else:
new_axis = [i + 2 for i in range(len(output_shape) - 2)] new_axis = [i + 2 for i in range(len(output_shape) - 2)]
predictions = da.map_blocks( predictions = da.map_blocks(
mapped_predict, mapped_predict,
@ -1001,28 +996,21 @@ async def _direct_predict_impl(
def _infer_predict_output( def _infer_predict_output(
booster: Booster, booster: Booster, features: int, is_df: bool, inplace: bool, **kwargs: Any
data: Union[DaskDMatrix, _DaskCollection],
inplace: bool,
**kwargs: Any
) -> Tuple[Tuple[int, ...], Dict[int, str]]: ) -> Tuple[Tuple[int, ...], Dict[int, str]]:
"""Create a dummy test sample to infer output shape for prediction.""" """Create a dummy test sample to infer output shape for prediction."""
if isinstance(data, DaskDMatrix): assert isinstance(features, int)
features = data.num_col()
else:
features = data.shape[1]
rng = numpy.random.RandomState(1994) rng = numpy.random.RandomState(1994)
test_sample = rng.randn(1, features) test_sample = rng.randn(1, features)
if inplace: if inplace:
# clear the state to avoid gpu_id, gpu_predictor kwargs = kwargs.copy()
booster = Booster(model_file=booster.save_raw()) if kwargs.pop("predict_type") == "margin":
test_predt = booster.inplace_predict(test_sample, **kwargs) kwargs["output_margin"] = True
else:
m = DMatrix(test_sample) m = DMatrix(test_sample)
test_predt = booster.predict(m, **kwargs) test_predt = booster.predict(m, validate_features=False, **kwargs)
n_columns = test_predt.shape[1] if len(test_predt.shape) > 1 else 1 n_columns = test_predt.shape[1] if len(test_predt.shape) > 1 else 1
meta: Dict[int, str] = {} meta: Dict[int, str] = {}
if _can_output_df(data, test_predt.shape): if _can_output_df(is_df, test_predt.shape):
for i in range(n_columns): for i in range(n_columns):
meta[i] = "f4" meta[i] = "f4"
return test_predt.shape, meta return test_predt.shape, meta
@ -1034,7 +1022,7 @@ async def _get_model_future(
if isinstance(model, Booster): if isinstance(model, Booster):
booster = await client.scatter(model, broadcast=True) booster = await client.scatter(model, broadcast=True)
elif isinstance(model, dict): elif isinstance(model, dict):
booster = await client.scatter(model["booster"]) booster = await client.scatter(model["booster"], broadcast=True)
elif isinstance(model, distributed.Future): elif isinstance(model, distributed.Future):
booster = model booster = model
if booster.type is not Booster: if booster.type is not Booster:
@ -1059,6 +1047,8 @@ async def _predict_async(
approx_contribs: bool, approx_contribs: bool,
pred_interactions: bool, pred_interactions: bool,
validate_features: bool, validate_features: bool,
iteration_range: Tuple[int, int],
strict_shape: bool,
) -> _DaskCollection: ) -> _DaskCollection:
_booster = await _get_model_future(client, model) _booster = await _get_model_future(client, model)
if not isinstance(data, (DaskDMatrix, da.Array, dd.DataFrame)): if not isinstance(data, (DaskDMatrix, da.Array, dd.DataFrame)):
@ -1077,43 +1067,51 @@ async def _predict_async(
approx_contribs=approx_contribs, approx_contribs=approx_contribs,
pred_interactions=pred_interactions, pred_interactions=pred_interactions,
validate_features=validate_features, validate_features=validate_features,
iteration_range=iteration_range,
strict_shape=strict_shape,
) )
if is_df and len(predt.shape) <= 2: if _can_output_df(is_df, predt.shape):
if lazy_isinstance(partition, "cudf", "core.dataframe.DataFrame"): if lazy_isinstance(partition, "cudf", "core.dataframe.DataFrame"):
import cudf import cudf
predt = cudf.DataFrame(predt, columns=columns) predt = cudf.DataFrame(predt, columns=columns, dtype=numpy.float32)
else: else:
predt = DataFrame(predt, columns=columns) predt = DataFrame(predt, columns=columns, dtype=numpy.float32)
return predt return predt
# Predict on dask collection directly. # Predict on dask collection directly.
if isinstance(data, (da.Array, dd.DataFrame)): if isinstance(data, (da.Array, dd.DataFrame)):
_output_shape, meta = _infer_predict_output( _output_shape, meta = await client.compute(
await _booster.result(), client.submit(
data, _infer_predict_output,
_booster,
features=data.shape[1],
is_df=isinstance(data, dd.DataFrame),
inplace=False, inplace=False,
output_margin=output_margin, output_margin=output_margin,
pred_leaf=pred_leaf, pred_leaf=pred_leaf,
pred_contribs=pred_contribs, pred_contribs=pred_contribs,
approx_contribs=approx_contribs, approx_contribs=approx_contribs,
pred_interactions=pred_interactions, pred_interactions=pred_interactions,
validate_features=False, )
) )
return await _direct_predict_impl( return await _direct_predict_impl(
mapped_predict, _booster, data, None, _output_shape, meta mapped_predict, _booster, data, None, _output_shape, meta
) )
output_shape, _ = _infer_predict_output( output_shape, _ = await client.compute(
booster=await _booster.result(), client.submit(
data=data, _infer_predict_output,
booster=_booster,
features=data.num_col(),
is_df=False,
inplace=False, inplace=False,
output_margin=output_margin, output_margin=output_margin,
pred_leaf=pred_leaf, pred_leaf=pred_leaf,
pred_contribs=pred_contribs, pred_contribs=pred_contribs,
approx_contribs=approx_contribs, approx_contribs=approx_contribs,
pred_interactions=pred_interactions, pred_interactions=pred_interactions,
validate_features=False, )
) )
# Prediction on dask DMatrix. # Prediction on dask DMatrix.
partition_order = data.partition_order partition_order = data.partition_order
@ -1180,7 +1178,7 @@ async def _predict_async(
futures[i], shape=(rows,) + output_shape[1:], dtype=numpy.float32 futures[i], shape=(rows,) + output_shape[1:], dtype=numpy.float32
) )
) )
predictions = await da.concatenate(arrays, axis=0) predictions = da.concatenate(arrays, axis=0)
return predictions return predictions
@ -1194,15 +1192,19 @@ def predict( # pylint: disable=unused-argument
pred_contribs: bool = False, pred_contribs: bool = False,
approx_contribs: bool = False, approx_contribs: bool = False,
pred_interactions: bool = False, pred_interactions: bool = False,
validate_features: bool = True validate_features: bool = True,
iteration_range: Tuple[int, int] = (0, 0),
strict_shape: bool = False,
) -> Any: ) -> Any:
'''Run prediction with a trained booster. '''Run prediction with a trained booster.
.. note:: .. note::
Using ``inplace_predict `` might be faster when meta information like Using ``inplace_predict`` might be faster when some features are not needed. See
``base_margin`` is not needed. For other parameters, please see :py:meth:`xgboost.Booster.predict` for details on various parameters. When using
``Booster.predict``. ``pred_interactions`` with mutli-class model, input should be ``da.Array`` or
``DaskDMatrix`` due to limitation in ``da.map_blocks``.
.. versionadded:: 1.0.0 .. versionadded:: 1.0.0
@ -1232,57 +1234,68 @@ def predict( # pylint: disable=unused-argument
''' '''
_assert_dask_support() _assert_dask_support()
client = _xgb_get_client(client) client = _xgb_get_client(client)
return client.sync( return client.sync(_predict_async, global_config=config.get_config(), **locals())
_predict_async, global_config=config.get_config(), **locals()
)
async def _inplace_predict_async( async def _inplace_predict_async( # pylint: disable=too-many-branches
client: "distributed.Client", client: "distributed.Client",
global_config: Dict[str, Any], global_config: Dict[str, Any],
model: Union[Booster, Dict, "distributed.Future"], model: Union[Booster, Dict, "distributed.Future"],
data: _DaskCollection, data: _DaskCollection,
iteration_range: Tuple[int, int] = (0, 0), iteration_range: Tuple[int, int],
predict_type: str = 'value', predict_type: str,
missing: float = numpy.nan missing: float,
validate_features: bool,
base_margin: Optional[_DaskCollection],
strict_shape: bool,
) -> _DaskCollection: ) -> _DaskCollection:
client = _xgb_get_client(client) client = _xgb_get_client(client)
booster = await _get_model_future(client, model) booster = await _get_model_future(client, model)
if not isinstance(data, (da.Array, dd.DataFrame)): if not isinstance(data, (da.Array, dd.DataFrame)):
raise TypeError(_expect([da.Array, dd.DataFrame], type(data))) raise TypeError(_expect([da.Array, dd.DataFrame], type(data)))
if base_margin is not None and not isinstance(
data, (da.Array, dd.DataFrame, dd.Series)
):
raise TypeError(_expect([da.Array, dd.DataFrame, dd.Series], type(base_margin)))
def mapped_predict( def mapped_predict(
booster: Booster, data: Any, is_df: bool, columns: List[int], _: Any booster: Booster, data: Any, is_df: bool, columns: List[int], base_margin: Any
) -> Any: ) -> Any:
with config.config_context(**global_config): with config.config_context(**global_config):
prediction = booster.inplace_predict( prediction = booster.inplace_predict(
data, data,
iteration_range=iteration_range, iteration_range=iteration_range,
predict_type=predict_type, predict_type=predict_type,
missing=missing missing=missing,
base_margin=base_margin,
validate_features=validate_features,
strict_shape=strict_shape,
) )
if is_df and len(prediction.shape) <= 2: if _can_output_df(is_df, prediction.shape):
if lazy_isinstance(data, 'cudf.core.dataframe', 'DataFrame'): if lazy_isinstance(data, "cudf.core.dataframe", "DataFrame"):
import cudf import cudf
prediction = cudf.DataFrame( prediction = cudf.DataFrame(
prediction, columns=columns, dtype=numpy.float32 prediction, columns=columns, dtype=numpy.float32
) )
else: else:
# If it's from pandas, the partition is a numpy array # If it's from pandas, the partition is a numpy array
prediction = DataFrame( prediction = DataFrame(prediction, columns=columns, dtype=numpy.float32)
prediction, columns=columns, dtype=numpy.float32
)
return prediction return prediction
# await turns future into value.
shape, meta = _infer_predict_output( shape, meta = await client.compute(
await booster.result(), client.submit(
data, _infer_predict_output,
True, booster,
features=data.shape[1],
is_df=isinstance(data, dd.DataFrame),
inplace=True,
predict_type=predict_type, predict_type=predict_type,
iteration_range=iteration_range iteration_range=iteration_range,
)
) )
return await _direct_predict_impl( return await _direct_predict_impl(
mapped_predict, booster, data, None, shape, meta mapped_predict, booster, data, base_margin, shape, meta
) )
@ -1291,10 +1304,13 @@ def inplace_predict( # pylint: disable=unused-argument
model: Union[TrainReturnT, Booster, "distributed.Future"], model: Union[TrainReturnT, Booster, "distributed.Future"],
data: _DaskCollection, data: _DaskCollection,
iteration_range: Tuple[int, int] = (0, 0), iteration_range: Tuple[int, int] = (0, 0),
predict_type: str = 'value', predict_type: str = "value",
missing: float = numpy.nan missing: float = numpy.nan,
validate_features: bool = True,
base_margin: Optional[_DaskCollection] = None,
strict_shape: bool = False,
) -> Any: ) -> Any:
'''Inplace prediction. """Inplace prediction. See doc in :py:meth:`xgboost.Booster.inplace_predict` for details.
.. versionadded:: 1.1.0 .. versionadded:: 1.1.0
@ -1304,16 +1320,27 @@ def inplace_predict( # pylint: disable=unused-argument
Specify the dask client used for training. Use default client Specify the dask client used for training. Use default client
returned from dask if it's set to None. returned from dask if it's set to None.
model: model:
The trained model. It can be a distributed.Future so user can See :py:func:`xgboost.dask.predict` for details.
pre-scatter it onto all workers. data :
dask collection.
iteration_range: iteration_range:
Specify the range of trees used for prediction. See :py:meth:`xgboost.Booster.predict` for details.
predict_type: predict_type:
* 'value': Normal prediction result. See :py:meth:`xgboost.Booster.inplace_predict` for details.
* 'margin': Output the raw untransformed margin value.
missing: missing:
Value in the input data which needs to be present as a missing Value in the input data which needs to be present as a missing
value. If None, defaults to np.nan. value. If None, defaults to np.nan.
base_margin:
See :py:obj:`xgboost.DMatrix` for details. Right now classifier is not well
supported with base_margin as it requires the size of base margin to be `n_classes
* n_samples`.
.. versionadded:: 1.4.0
strict_shape:
See :py:meth:`xgboost.Booster.predict` for details.
.. versionadded:: 1.4.0
Returns Returns
------- -------
@ -1322,7 +1349,7 @@ def inplace_predict( # pylint: disable=unused-argument
data is ``dask.dataframe.DataFrame``, return value can be data is ``dask.dataframe.DataFrame``, return value can be
``dask.dataframe.Series``, ``dask.dataframe.DataFrame`` or ``dask.array.Array``, ``dask.dataframe.Series``, ``dask.dataframe.DataFrame`` or ``dask.array.Array``,
depending on the output shape. depending on the output shape.
''' """
_assert_dask_support() _assert_dask_support()
client = _xgb_get_client(client) client = _xgb_get_client(client)
return client.sync( return client.sync(
@ -1334,9 +1361,11 @@ async def _async_wrap_evaluation_matrices(
client: "distributed.Client", **kwargs: Any client: "distributed.Client", **kwargs: Any
) -> Tuple[DaskDMatrix, Optional[List[Tuple[DaskDMatrix, str]]]]: ) -> Tuple[DaskDMatrix, Optional[List[Tuple[DaskDMatrix, str]]]]:
"""A switch function for async environment.""" """A switch function for async environment."""
def _inner(**kwargs: Any) -> DaskDMatrix: def _inner(**kwargs: Any) -> DaskDMatrix:
m = DaskDMatrix(client=client, **kwargs) m = DaskDMatrix(client=client, **kwargs)
return m return m
train_dmatrix, evals = _wrap_evaluation_matrices(create_dmatrix=_inner, **kwargs) train_dmatrix, evals = _wrap_evaluation_matrices(create_dmatrix=_inner, **kwargs)
train_dmatrix = await train_dmatrix train_dmatrix = await train_dmatrix
if evals is None: if evals is None:
@ -1351,25 +1380,45 @@ async def _async_wrap_evaluation_matrices(
class DaskScikitLearnBase(XGBModel): class DaskScikitLearnBase(XGBModel):
'''Base class for implementing scikit-learn interface with Dask''' """Base class for implementing scikit-learn interface with Dask"""
_client = None _client = None
async def _predict_async( async def _predict_async(
self, data: _DaskCollection, self,
output_margin: bool = False, data: _DaskCollection,
validate_features: bool = True, output_margin: bool,
base_margin: Optional[_DaskCollection] = None validate_features: bool,
base_margin: Optional[_DaskCollection],
iteration_range: Optional[Tuple[int, int]],
) -> Any: ) -> Any:
test_dmatrix = await DaskDMatrix( iteration_range = self._get_iteration_range(iteration_range)
client=self.client, data=data, base_margin=base_margin, if self._can_use_inplace_predict():
missing=self.missing predts = await inplace_predict(
client=self.client,
model=self.get_booster(),
data=data,
iteration_range=iteration_range,
predict_type="margin" if output_margin else "value",
missing=self.missing,
base_margin=base_margin,
validate_features=validate_features,
) )
pred_probs = await predict(client=self.client, if isinstance(predts, dd.DataFrame):
model=self.get_booster(), data=test_dmatrix, predts = predts.to_dask_array()
else:
test_dmatrix = await DaskDMatrix(
self.client, data=data, base_margin=base_margin, missing=self.missing
)
predts = await predict(
self.client,
model=self.get_booster(),
data=test_dmatrix,
output_margin=output_margin, output_margin=output_margin,
validate_features=validate_features) validate_features=validate_features,
return pred_probs iteration_range=iteration_range,
)
return predts
def predict( def predict(
self, self,
@ -1377,26 +1426,56 @@ class DaskScikitLearnBase(XGBModel):
output_margin: bool = False, output_margin: bool = False,
ntree_limit: Optional[int] = None, ntree_limit: Optional[int] = None,
validate_features: bool = True, validate_features: bool = True,
base_margin: Optional[_DaskCollection] = None base_margin: Optional[_DaskCollection] = None,
iteration_range: Optional[Tuple[int, int]] = None,
) -> Any: ) -> Any:
_assert_dask_support() _assert_dask_support()
msg = '`ntree_limit` is not supported on dask, use model slicing instead.' msg = "`ntree_limit` is not supported on dask, use `iteration_range` instead."
assert ntree_limit is None, msg assert ntree_limit is None, msg
return self.client.sync( return self.client.sync(
self._predict_async, self._predict_async,
X, X,
output_margin=output_margin, output_margin=output_margin,
validate_features=validate_features, validate_features=validate_features,
base_margin=base_margin base_margin=base_margin,
iteration_range=iteration_range,
) )
async def _apply_async(
self,
X: _DaskCollection,
iteration_range: Optional[Tuple[int, int]] = None,
) -> Any:
iteration_range = self._get_iteration_range(iteration_range)
test_dmatrix = await DaskDMatrix(self.client, data=X, missing=self.missing)
predts = await predict(
self.client,
model=self.get_booster(),
data=test_dmatrix,
pred_leaf=True,
iteration_range=iteration_range,
)
return predts
def apply(
self,
X: _DaskCollection,
ntree_limit: Optional[int] = None,
iteration_range: Optional[Tuple[int, int]] = None,
) -> Any:
_assert_dask_support()
msg = "`ntree_limit` is not supported on dask, use `iteration_range` instead."
assert ntree_limit is None, msg
return self.client.sync(self._apply_async, X, iteration_range=iteration_range)
def __await__(self) -> Awaitable[Any]: def __await__(self) -> Awaitable[Any]:
# Generate a coroutine wrapper to make this class awaitable. # Generate a coroutine wrapper to make this class awaitable.
async def _() -> Awaitable[Any]: async def _() -> Awaitable[Any]:
return self return self
return self.client.sync(_).__await__() return self.client.sync(_).__await__()
def __getstate__(self): def __getstate__(self) -> Dict:
this = self.__dict__.copy() this = self.__dict__.copy()
if "_client" in this.keys(): if "_client" in this.keys():
del this["_client"] del this["_client"]
@ -1404,7 +1483,7 @@ class DaskScikitLearnBase(XGBModel):
@property @property
def client(self) -> "distributed.Client": def client(self) -> "distributed.Client":
'''The dask client used in this model.''' """The dask client used in this model."""
client = _xgb_get_client(self._client) client = _xgb_get_client(self._client)
return client return client
@ -1494,7 +1573,7 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
sample_weight_eval_set: Optional[List[_DaskCollection]] = None, sample_weight_eval_set: Optional[List[_DaskCollection]] = None,
base_margin_eval_set: Optional[List[_DaskCollection]] = None, base_margin_eval_set: Optional[List[_DaskCollection]] = None,
feature_weights: Optional[_DaskCollection] = None, feature_weights: Optional[_DaskCollection] = None,
callbacks: Optional[List[TrainingCallback]] = None callbacks: Optional[List[TrainingCallback]] = None,
) -> "DaskXGBRegressor": ) -> "DaskXGBRegressor":
_assert_dask_support() _assert_dask_support()
args = {k: v for k, v in locals().items() if k != "self"} args = {k: v for k, v in locals().items() if k != "self"}
@ -1556,9 +1635,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
else: else:
obj = None obj = None
model, metric, params = self._configure_fit( model, metric, params = self._configure_fit(
booster=xgb_model, booster=xgb_model, eval_metric=eval_metric, params=params
eval_metric=eval_metric,
params=params
) )
results = await train( results = await train(
client=self.client, client=self.client,
@ -1610,18 +1687,19 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
X: _DaskCollection, X: _DaskCollection,
validate_features: bool, validate_features: bool,
output_margin: bool, output_margin: bool,
base_margin: Optional[_DaskCollection] base_margin: Optional[_DaskCollection],
iteration_range: Optional[Tuple[int, int]],
) -> _DaskCollection: ) -> _DaskCollection:
test_dmatrix = await DaskDMatrix( if iteration_range is None:
client=self.client, data=X, base_margin=base_margin, iteration_range = (0, 0)
missing=self.missing predts = await super()._predict_async(
) data=X,
pred_probs = await predict(client=self.client, output_margin=output_margin,
model=self.get_booster(),
data=test_dmatrix,
validate_features=validate_features, validate_features=validate_features,
output_margin=output_margin) base_margin=base_margin,
return _cls_predict_proba(self.objective, pred_probs, da.vstack) iteration_range=iteration_range,
)
return _cls_predict_proba(self.objective, predts, da.vstack)
# pylint: disable=missing-function-docstring # pylint: disable=missing-function-docstring
def predict_proba( def predict_proba(
@ -1630,37 +1708,49 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
ntree_limit: Optional[int] = None, ntree_limit: Optional[int] = None,
validate_features: bool = True, validate_features: bool = True,
output_margin: bool = False, output_margin: bool = False,
base_margin: Optional[_DaskCollection] = None base_margin: Optional[_DaskCollection] = None,
iteration_range: Optional[Tuple[int, int]] = None,
) -> Any: ) -> Any:
_assert_dask_support() _assert_dask_support()
msg = '`ntree_limit` is not supported on dask, use model slicing instead.' msg = "`ntree_limit` is not supported on dask, use `iteration_range` instead."
assert ntree_limit is None, msg assert ntree_limit is None, msg
return self.client.sync( return self.client.sync(
self._predict_proba_async, self._predict_proba_async,
X=X, X=X,
validate_features=validate_features, validate_features=validate_features,
output_margin=output_margin, output_margin=output_margin,
base_margin=base_margin base_margin=base_margin,
iteration_range=iteration_range,
) )
predict_proba.__doc__ = XGBClassifier.predict_proba.__doc__ predict_proba.__doc__ = XGBClassifier.predict_proba.__doc__
async def _predict_async( async def _predict_async(
self, data: _DaskCollection, self,
output_margin: bool = False, data: _DaskCollection,
validate_features: bool = True, output_margin: bool,
base_margin: Optional[_DaskCollection] = None validate_features: bool,
base_margin: Optional[_DaskCollection],
iteration_range: Optional[Tuple[int, int]],
) -> _DaskCollection: ) -> _DaskCollection:
pred_probs = await super()._predict_async( pred_probs = await super()._predict_async(
data, output_margin, validate_features, base_margin data, output_margin, validate_features, base_margin, iteration_range
) )
if output_margin: if output_margin:
return pred_probs return pred_probs
if self.n_classes_ == 2: if len(pred_probs.shape) == 1:
preds = (pred_probs > 0.5).astype(int) preds = (pred_probs > 0.5).astype(int)
else: else:
preds = da.argmax(pred_probs, axis=1) assert len(pred_probs.shape) == 2
assert isinstance(pred_probs, da.Array)
# when using da.argmax directly, dask will construct a numpy based return
# array, which runs into error when computing GPU based prediction.
def _argmax(x: Any) -> Any:
return x.argmax(axis=1)
preds = da.map_blocks(_argmax, pred_probs, drop_axis=1)
return preds return preds
@ -1770,7 +1860,7 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
callbacks: Optional[List[TrainingCallback]] = None callbacks: Optional[List[TrainingCallback]] = None
) -> "DaskXGBRanker": ) -> "DaskXGBRanker":
_assert_dask_support() _assert_dask_support()
args = {k: v for k, v in locals().items() if k != 'self'} args = {k: v for k, v in locals().items() if k != "self"}
return self.client.sync(self._fit_async, **args) return self.client.sync(self._fit_async, **args)
# FIXME(trivialfis): arguments differ due to additional parameters like group and qid. # FIXME(trivialfis): arguments differ due to additional parameters like group and qid.

View File

@ -6,7 +6,8 @@ import warnings
import json import json
from typing import Union, Optional, List, Dict, Callable, Tuple, Any from typing import Union, Optional, List, Dict, Callable, Tuple, Any
import numpy as np import numpy as np
from .core import Booster, DMatrix, XGBoostError, _deprecate_positional_args from .core import Booster, DMatrix, XGBoostError
from .core import _deprecate_positional_args, _convert_ntree_limit
from .core import Metric from .core import Metric
from .training import train from .training import train
from .data import _is_cudf_df, _is_cudf_ser, _is_cupy_array from .data import _is_cudf_df, _is_cudf_ser, _is_cupy_array
@ -413,8 +414,8 @@ class XGBModel(XGBModelBase):
# Simple optimization to gain speed (inspect is slow) # Simple optimization to gain speed (inspect is slow)
return self return self
# this concatenates kwargs into paraemters, enabling `get_params` for # this concatenates kwargs into parameters, enabling `get_params` for
# obtaining parameters from keyword paraemters. # obtaining parameters from keyword parameters.
for key, value in params.items(): for key, value in params.items():
if hasattr(self, key): if hasattr(self, key):
setattr(self, key, value) setattr(self, key, value)
@ -747,26 +748,45 @@ class XGBModel(XGBModelBase):
self._set_evaluation_result(evals_result) self._set_evaluation_result(evals_result)
return self return self
def _can_use_inplace_predict(self) -> bool:
# When predictor is explicitly set, using `inplace_predict` might result into
# error with incompatible data type.
# Inplace predict doesn't handle as many data types as DMatrix, but it's
# sufficient for dask interface where input is simpiler.
params = self.get_params()
booster = self.booster
if params.get("predictor", None) is None and (
booster is None or booster == "gbtree"
):
return True
return False
def _get_iteration_range(
self, iteration_range: Optional[Tuple[int, int]]
) -> Tuple[int, int]:
if (iteration_range is None or iteration_range[1] == 0):
# Use best_iteration if defined.
try:
iteration_range = (0, self.best_iteration + 1)
except AttributeError:
iteration_range = (0, 0)
if self.booster == "gblinear":
iteration_range = (0, 0)
return iteration_range
def predict( def predict(
self, self,
X, X,
output_margin=False, output_margin=False,
ntree_limit=None, ntree_limit=None,
validate_features=True, validate_features=True,
base_margin=None base_margin=None,
iteration_range=None,
): ):
""" """
Predict with `X`. Predict with `X`.
.. note:: This function is not thread safe. .. note:: This function is only thread safe for `gbtree`
For each booster object, predict can only be called from one thread.
If you want to run prediction using multiple thread, call ``xgb.copy()`` to make copies
of model object and then call ``predict()``.
.. code-block:: python
preds = bst.predict(dtest, ntree_limit=num_round)
Parameters Parameters
---------- ----------
@ -775,37 +795,40 @@ class XGBModel(XGBModelBase):
output_margin : bool output_margin : bool
Whether to output the raw untransformed margin value. Whether to output the raw untransformed margin value.
ntree_limit : int ntree_limit : int
Limit number of trees in the prediction; defaults to best_ntree_limit if Deprecated, use `iteration_range` instead.
defined (i.e. it has been trained with early stopping), otherwise 0 (use all
trees).
validate_features : bool validate_features : bool
When this is True, validate that the Booster's and data's feature_names are identical. When this is True, validate that the Booster's and data's feature_names are
Otherwise, it is assumed that the feature_names are the same. identical. Otherwise, it is assumed that the feature_names are the same.
base_margin : array_like base_margin : array_like
Margin added to prediction. Margin added to prediction.
iteration_range :
Specifies which layer of trees are used in prediction. For example, if a
random forest is trained with 100 rounds. Specifying `iteration_range=(10,
20)`, then only the forests built during [10, 20) (half open set) rounds are
used in this prediction.
.. versionadded:: 1.4.0
Returns Returns
------- -------
prediction : numpy array prediction : numpy array
""" """
# pylint: disable=missing-docstring,invalid-name iteration_range = _convert_ntree_limit(
test_dmatrix = DMatrix(X, base_margin=base_margin, self.get_booster(), ntree_limit, iteration_range
missing=self.missing, nthread=self.n_jobs) )
# get ntree_limit to use - if none specified, default to iteration_range = self._get_iteration_range(iteration_range)
# best_ntree_limit if defined, otherwise 0. test = DMatrix(
if ntree_limit is None: X, base_margin=base_margin, missing=self.missing, nthread=self.n_jobs
try: )
ntree_limit = self.best_ntree_limit
except AttributeError:
ntree_limit = 0
return self.get_booster().predict( return self.get_booster().predict(
test_dmatrix, data=test,
iteration_range=iteration_range,
output_margin=output_margin, output_margin=output_margin,
ntree_limit=ntree_limit, validate_features=validate_features,
validate_features=validate_features
) )
def apply(self, X, ntree_limit=0): def apply(
self, X, ntree_limit: int = 0, iteration_range: Optional[Tuple[int, int]] = None
) -> np.ndarray:
"""Return the predicted leaf every tree for each sample. """Return the predicted leaf every tree for each sample.
Parameters Parameters
@ -823,10 +846,16 @@ class XGBModel(XGBModelBase):
leaf x ends up in. Leaves are numbered within leaf x ends up in. Leaves are numbered within
``[0; 2**(self.max_depth+1))``, possibly with gaps in the numbering. ``[0; 2**(self.max_depth+1))``, possibly with gaps in the numbering.
""" """
iteration_range = _convert_ntree_limit(
self.get_booster(), ntree_limit, iteration_range
)
iteration_range = self._get_iteration_range(iteration_range)
test_dmatrix = DMatrix(X, missing=self.missing, nthread=self.n_jobs) test_dmatrix = DMatrix(X, missing=self.missing, nthread=self.n_jobs)
return self.get_booster().predict(test_dmatrix, return self.get_booster().predict(
test_dmatrix,
pred_leaf=True, pred_leaf=True,
ntree_limit=ntree_limit) iteration_range=iteration_range
)
def evals_result(self): def evals_result(self):
"""Return the evaluation results. """Return the evaluation results.
@ -945,8 +974,7 @@ class XGBModel(XGBModelBase):
'Coefficients are not defined for Booster type {}' 'Coefficients are not defined for Booster type {}'
.format(self.booster)) .format(self.booster))
b = self.get_booster() b = self.get_booster()
coef = np.array(json.loads( coef = np.array(json.loads(b.get_dump(dump_format='json')[0])['weight'])
b.get_dump(dump_format='json')[0])['weight'])
# Logic for multiclass classification # Logic for multiclass classification
n_classes = getattr(self, 'n_classes_', None) n_classes = getattr(self, 'n_classes_', None)
if n_classes is not None: if n_classes is not None:
@ -1157,14 +1185,16 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
output_margin=False, output_margin=False,
ntree_limit=None, ntree_limit=None,
validate_features=True, validate_features=True,
base_margin=None base_margin=None,
iteration_range: Optional[Tuple[int, int]] = None,
): ):
class_probs = super().predict( class_probs = super().predict(
X=X, X=X,
output_margin=output_margin, output_margin=output_margin,
ntree_limit=ntree_limit, ntree_limit=ntree_limit,
validate_features=validate_features, validate_features=validate_features,
base_margin=base_margin base_margin=base_margin,
iteration_range=iteration_range,
) )
if output_margin: if output_margin:
# If output_margin is active, simply return the scores # If output_margin is active, simply return the scores
@ -1180,29 +1210,34 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
return self._le.inverse_transform(column_indexes) return self._le.inverse_transform(column_indexes)
return column_indexes return column_indexes
def predict_proba(self, X, ntree_limit=None, validate_features=False, def predict_proba(
base_margin=None): self,
X,
ntree_limit=None,
validate_features=False,
base_margin=None,
iteration_range: Optional[Tuple[int, int]] = None,
) -> np.ndarray:
""" Predict the probability of each `X` example being of a given class. """ Predict the probability of each `X` example being of a given class.
.. note:: This function is not thread safe .. note:: This function is only thread safe for `gbtree`
For each booster object, predict can only be called from one
thread. If you want to run prediction using multiple thread, call
``xgb.copy()`` to make copies of model object and then call predict
Parameters Parameters
---------- ----------
X : array_like X : array_like
Feature matrix. Feature matrix.
ntree_limit : int ntree_limit : int
Limit number of trees in the prediction; defaults to best_ntree_limit if Deprecated, use `iteration_range` instead.
defined (i.e. it has been trained with early stopping), otherwise 0 (use all
trees).
validate_features : bool validate_features : bool
When this is True, validate that the Booster's and data's feature_names are When this is True, validate that the Booster's and data's feature_names are
identical. Otherwise, it is assumed that the feature_names are the same. identical. Otherwise, it is assumed that the feature_names are the same.
base_margin : array_like base_margin : array_like
Margin added to prediction. Margin added to prediction.
iteration_range :
Specifies which layer of trees are used in prediction. For example, if a
random forest is trained with 100 rounds. Specifying `iteration_range=(10,
20)`, then only the forests built during [10, 20) (half open set) rounds are
used in this prediction.
Returns Returns
------- -------
@ -1215,7 +1250,8 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
output_margin=False, output_margin=False,
ntree_limit=ntree_limit, ntree_limit=ntree_limit,
validate_features=validate_features, validate_features=validate_features,
base_margin=base_margin base_margin=base_margin,
iteration_range=iteration_range
) )
return _cls_predict_proba(self.objective, class_probs, np.vstack) return _cls_predict_proba(self.objective, class_probs, np.vstack)

View File

@ -4,9 +4,8 @@
"""Training Library containing training routines.""" """Training Library containing training routines."""
import warnings import warnings
import copy import copy
import json
import numpy as np import numpy as np
from .core import Booster, XGBoostError from .core import Booster, XGBoostError, _get_booster_layer_trees
from .compat import (SKLEARN_INSTALLED, XGBStratifiedKFold) from .compat import (SKLEARN_INSTALLED, XGBStratifiedKFold)
from . import callback from . import callback
@ -91,24 +90,7 @@ def _train_internal(params, dtrain,
# These should be moved into callback functions `after_training`, but until old # These should be moved into callback functions `after_training`, but until old
# callbacks are removed, the train function is the only place for setting the # callbacks are removed, the train function is the only place for setting the
# attributes. # attributes.
config = json.loads(bst.save_config()) num_parallel_tree, _ = _get_booster_layer_trees(bst)
booster = config['learner']['gradient_booster']['name']
if booster == 'gblinear':
num_parallel_tree = 0
elif booster == 'dart':
num_parallel_tree = int(
config['learner']['gradient_booster']['gbtree']['gbtree_train_param'][
'num_parallel_tree'
]
)
elif booster == 'gbtree':
num_parallel_tree = int(
config['learner']['gradient_booster']['gbtree_train_param'][
'num_parallel_tree']
)
else:
raise ValueError(f'Unknown booster: {booster}')
if bst.attr('best_score') is not None: if bst.attr('best_score') is not None:
bst.best_score = float(bst.attr('best_score')) bst.best_score = float(bst.attr('best_score'))
bst.best_iteration = int(bst.attr('best_iteration')) bst.best_iteration = int(bst.attr('best_iteration'))

View File

@ -619,20 +619,58 @@ XGB_DLL int XGBoosterPredict(BoosterHandle handle,
CHECK_HANDLE(); CHECK_HANDLE();
auto *learner = static_cast<Learner*>(handle); auto *learner = static_cast<Learner*>(handle);
auto& entry = learner->GetThreadLocal().prediction_entry; auto& entry = learner->GetThreadLocal().prediction_entry;
learner->Predict( auto iteration_end = GetIterationFromTreeLimit(ntree_limit, learner);
*static_cast<std::shared_ptr<DMatrix>*>(dmat), learner->Predict(*static_cast<std::shared_ptr<DMatrix> *>(dmat),
(option_mask & 1) != 0, (option_mask & 1) != 0, &entry.predictions, 0, iteration_end,
&entry.predictions, ntree_limit, static_cast<bool>(training), (option_mask & 2) != 0,
static_cast<bool>(training), (option_mask & 4) != 0, (option_mask & 8) != 0,
(option_mask & 2) != 0,
(option_mask & 4) != 0,
(option_mask & 8) != 0,
(option_mask & 16) != 0); (option_mask & 16) != 0);
*out_result = dmlc::BeginPtr(entry.predictions.ConstHostVector()); *out_result = dmlc::BeginPtr(entry.predictions.ConstHostVector());
*len = static_cast<xgboost::bst_ulong>(entry.predictions.Size()); *len = static_cast<xgboost::bst_ulong>(entry.predictions.Size());
API_END(); API_END();
} }
XGB_DLL int XGBoosterPredictFromDMatrix(BoosterHandle handle,
DMatrixHandle dmat,
char const* c_json_config,
xgboost::bst_ulong const **out_shape,
xgboost::bst_ulong *out_dim,
bst_float const **out_result) {
API_BEGIN();
if (handle == nullptr) {
LOG(FATAL) << "Booster has not been intialized or has already been disposed.";
}
if (dmat == nullptr) {
LOG(FATAL) << "DMatrix has not been intialized or has already been disposed.";
}
auto config = Json::Load(StringView{c_json_config});
auto *learner = static_cast<Learner*>(handle);
auto& entry = learner->GetThreadLocal().prediction_entry;
auto p_m = *static_cast<std::shared_ptr<DMatrix> *>(dmat);
auto type = PredictionType(get<Integer const>(config["type"]));
auto iteration_begin = get<Integer const>(config["iteration_begin"]);
auto iteration_end = get<Integer const>(config["iteration_end"]);
learner->Predict(
*static_cast<std::shared_ptr<DMatrix> *>(dmat),
type == PredictionType::kMargin, &entry.predictions, iteration_begin,
iteration_end, get<Boolean const>(config["training"]),
type == PredictionType::kLeaf, type == PredictionType::kContribution,
type == PredictionType::kApproxContribution,
type == PredictionType::kInteraction);
*out_result = dmlc::BeginPtr(entry.predictions.ConstHostVector());
auto &shape = learner->GetThreadLocal().prediction_shape;
auto chunksize = p_m->Info().num_row_ == 0 ? 0 : entry.predictions.Size() / p_m->Info().num_row_;
auto rounds = iteration_end - iteration_begin;
rounds = rounds == 0 ? learner->BoostedRounds() : rounds;
// Determine shape
bool strict_shape = get<Boolean const>(config["strict_shape"]);
CalcPredictShape(strict_shape, type, p_m->Info().num_row_,
p_m->Info().num_col_, chunksize, learner->Groups(), rounds,
&shape, out_dim);
*out_shape = dmlc::BeginPtr(shape);
API_END();
}
template <typename T> template <typename T>
void InplacePredictImpl(std::shared_ptr<T> x, std::shared_ptr<DMatrix> p_m, void InplacePredictImpl(std::shared_ptr<T> x, std::shared_ptr<DMatrix> p_m,
@ -705,7 +743,7 @@ XGB_DLL int XGBoosterPredictFromCSR(BoosterHandle handle, char const *indptr,
} }
#if !defined(XGBOOST_USE_CUDA) #if !defined(XGBOOST_USE_CUDA)
XGB_DLL int XGBoosterPredictFromArrayInterface( XGB_DLL int XGBoosterPredictFromCUDAArray(
BoosterHandle handle, char const *c_json_strs, char const *c_json_config, BoosterHandle handle, char const *c_json_strs, char const *c_json_config,
DMatrixHandle m, xgboost::bst_ulong const **out_shape, xgboost::bst_ulong *out_dim, DMatrixHandle m, xgboost::bst_ulong const **out_shape, xgboost::bst_ulong *out_dim,
const float **out_result) { const float **out_result) {
@ -715,7 +753,7 @@ XGB_DLL int XGBoosterPredictFromArrayInterface(
API_END(); API_END();
} }
XGB_DLL int XGBoosterPredictFromArrayInterfaceColumns( XGB_DLL int XGBoosterPredictFromCUDAColumnar(
BoosterHandle handle, char const *c_json_strs, char const *c_json_config, BoosterHandle handle, char const *c_json_strs, char const *c_json_config,
DMatrixHandle m, xgboost::bst_ulong const **out_shape, xgboost::bst_ulong *out_dim, DMatrixHandle m, xgboost::bst_ulong const **out_shape, xgboost::bst_ulong *out_dim,
const float **out_result) { const float **out_result) {

View File

@ -66,8 +66,7 @@ int InplacePreidctCuda(BoosterHandle handle, char const *c_json_strs,
API_END(); API_END();
} }
// A hidden API as cache id is not being supported yet. XGB_DLL int XGBoosterPredictFromCudaColumnar(
XGB_DLL int XGBoosterPredictFromArrayInterfaceColumns(
BoosterHandle handle, char const *c_json_strs, char const *c_json_config, BoosterHandle handle, char const *c_json_strs, char const *c_json_config,
DMatrixHandle m, xgboost::bst_ulong const **out_shape, DMatrixHandle m, xgboost::bst_ulong const **out_shape,
xgboost::bst_ulong *out_dim, const float **out_result) { xgboost::bst_ulong *out_dim, const float **out_result) {
@ -79,8 +78,7 @@ XGB_DLL int XGBoosterPredictFromArrayInterfaceColumns(
handle, c_json_strs, c_json_config, p_m, out_shape, out_dim, out_result); handle, c_json_strs, c_json_config, p_m, out_shape, out_dim, out_result);
} }
// A hidden API as cache id is not being supported yet. XGB_DLL int XGBoosterPredictFromCudaArray(
XGB_DLL int XGBoosterPredictFromArrayInterface(
BoosterHandle handle, char const *c_json_strs, char const *c_json_config, BoosterHandle handle, char const *c_json_strs, char const *c_json_config,
DMatrixHandle m, xgboost::bst_ulong const **out_shape, DMatrixHandle m, xgboost::bst_ulong const **out_shape,
xgboost::bst_ulong *out_dim, const float **out_result) { xgboost::bst_ulong *out_dim, const float **out_result) {

View File

@ -9,6 +9,7 @@
#include <vector> #include <vector>
#include "xgboost/logging.h" #include "xgboost/logging.h"
#include "xgboost/json.h"
#include "xgboost/learner.h" #include "xgboost/learner.h"
namespace xgboost { namespace xgboost {
@ -30,8 +31,8 @@ inline void CalcPredictShape(bool strict_shape, PredictionType type, size_t rows
std::vector<bst_ulong> *out_shape, std::vector<bst_ulong> *out_shape,
xgboost::bst_ulong *out_dim) { xgboost::bst_ulong *out_dim) {
auto &shape = *out_shape; auto &shape = *out_shape;
if ((type == PredictionType::kMargin || type == PredictionType::kValue) && if (type == PredictionType::kMargin && rows != 0) {
rows != 0) { // When kValue is used, softmax can change the chunksize.
CHECK_EQ(chunksize, groups); CHECK_EQ(chunksize, groups);
} }
@ -110,5 +111,35 @@ inline void CalcPredictShape(bool strict_shape, PredictionType type, size_t rows
std::accumulate(shape.cbegin(), shape.cend(), 1, std::multiplies<>{}), std::accumulate(shape.cbegin(), shape.cend(), 1, std::multiplies<>{}),
chunksize * rows); chunksize * rows);
} }
// Reverse the ntree_limit in old prediction API.
inline uint32_t GetIterationFromTreeLimit(uint32_t ntree_limit, Learner *learner) {
// On Python and R, `best_ntree_limit` is set to `best_iteration * num_parallel_tree`.
// To reverse it we just divide it by `num_parallel_tree`.
if (ntree_limit != 0) {
learner->Configure();
uint32_t num_parallel_tree = 0;
Json config{Object()};
learner->SaveConfig(&config);
auto const &booster =
get<String const>(config["learner"]["gradient_booster"]["name"]);
if (booster == "gblinear") {
num_parallel_tree = 0;
} else if (booster == "dart") {
num_parallel_tree = std::stoi(
get<String const>(config["learner"]["gradient_booster"]["gbtree"]
["gbtree_train_param"]["num_parallel_tree"]));
} else if (booster == "gbtree") {
num_parallel_tree = std::stoi(get<String const>(
(config["learner"]["gradient_booster"]["gbtree_train_param"]
["num_parallel_tree"])));
} else {
LOG(FATAL) << "Unknown booster:" << booster;
}
ntree_limit /= std::max(num_parallel_tree, 1u);
}
return ntree_limit;
}
} // namespace xgboost } // namespace xgboost
#endif // XGBOOST_C_API_C_API_UTILS_H_ #endif // XGBOOST_C_API_C_API_UTILS_H_

View File

@ -25,6 +25,7 @@
#include "common/config.h" #include "common/config.h"
#include "common/io.h" #include "common/io.h"
#include "common/version.h" #include "common/version.h"
#include "c_api/c_api_utils.h"
namespace xgboost { namespace xgboost {
enum CLITask { enum CLITask {
@ -58,6 +59,8 @@ struct CLIParam : public XGBoostParameter<CLIParam> {
int dsplit; int dsplit;
/*!\brief limit number of trees in prediction */ /*!\brief limit number of trees in prediction */
int ntree_limit; int ntree_limit;
int iteration_begin;
int iteration_end;
/*!\brief whether to directly output margin value */ /*!\brief whether to directly output margin value */
bool pred_margin; bool pred_margin;
/*! \brief whether dump statistics along with model */ /*! \brief whether dump statistics along with model */
@ -109,7 +112,11 @@ struct CLIParam : public XGBoostParameter<CLIParam> {
.add_enum("row", 2) .add_enum("row", 2)
.describe("Data split mode."); .describe("Data split mode.");
DMLC_DECLARE_FIELD(ntree_limit).set_default(0).set_lower_bound(0) DMLC_DECLARE_FIELD(ntree_limit).set_default(0).set_lower_bound(0)
.describe("Number of trees used for prediction, 0 means use all trees."); .describe("(Deprecated) Use iteration_begin/iteration_end instead.");
DMLC_DECLARE_FIELD(iteration_begin).set_default(0).set_lower_bound(0)
.describe("Begining of boosted tree iteration used for prediction.");
DMLC_DECLARE_FIELD(iteration_end).set_default(0).set_lower_bound(0)
.describe("End of boosted tree iteration used for prediction. 0 means all the trees.");
DMLC_DECLARE_FIELD(pred_margin).set_default(false) DMLC_DECLARE_FIELD(pred_margin).set_default(false)
.describe("Whether to predict margin value instead of probability."); .describe("Whether to predict margin value instead of probability.");
DMLC_DECLARE_FIELD(dump_stats).set_default(false) DMLC_DECLARE_FIELD(dump_stats).set_default(false)
@ -334,7 +341,13 @@ class CLI {
LOG(INFO) << "Start prediction..."; LOG(INFO) << "Start prediction...";
HostDeviceVector<bst_float> preds; HostDeviceVector<bst_float> preds;
learner_->Predict(dtest, param_.pred_margin, &preds, param_.ntree_limit); if (param_.ntree_limit != 0) {
param_.iteration_end = GetIterationFromTreeLimit(param_.ntree_limit, learner_.get());
LOG(WARNING) << "`ntree_limit` is deprecated, use `iteration_begin` and "
"`iteration_end` instead.";
}
learner_->Predict(dtest, param_.pred_margin, &preds, param_.iteration_begin,
param_.iteration_end);
LOG(CONSOLE) << "Writing prediction to " << param_.name_pred; LOG(CONSOLE) << "Writing prediction to " << param_.name_pred;
std::unique_ptr<dmlc::Stream> fo( std::unique_ptr<dmlc::Stream> fo(

View File

@ -47,6 +47,12 @@ struct GBLinearTrainParam : public XGBoostParameter<GBLinearTrainParam> {
.describe("Maximum rows per batch."); .describe("Maximum rows per batch.");
} }
}; };
void LinearCheckLayer(unsigned layer_begin, unsigned layer_end) {
CHECK_EQ(layer_begin, 0) << "Linear booster does not support prediction range.";
CHECK_EQ(layer_end, 0) << "Linear booster does not support prediction range.";
}
/*! /*!
* \brief gradient boosted linear model * \brief gradient boosted linear model
*/ */
@ -130,20 +136,19 @@ class GBLinear : public GradientBooster {
monitor_.Stop("DoBoost"); monitor_.Stop("DoBoost");
} }
void PredictBatch(DMatrix *p_fmat, void PredictBatch(DMatrix *p_fmat, PredictionCacheEntry *predts,
PredictionCacheEntry *predts, bool training, unsigned layer_begin, unsigned layer_end) override {
bool, unsigned ntree_limit) override {
monitor_.Start("PredictBatch"); monitor_.Start("PredictBatch");
LinearCheckLayer(layer_begin, layer_end);
auto* out_preds = &predts->predictions; auto* out_preds = &predts->predictions;
CHECK_EQ(ntree_limit, 0U)
<< "GBLinear::Predict ntrees is only valid for gbtree predictor";
this->PredictBatchInternal(p_fmat, &out_preds->HostVector()); this->PredictBatchInternal(p_fmat, &out_preds->HostVector());
monitor_.Stop("PredictBatch"); monitor_.Stop("PredictBatch");
} }
// add base margin // add base margin
void PredictInstance(const SparsePage::Inst &inst, void PredictInstance(const SparsePage::Inst &inst,
std::vector<bst_float> *out_preds, std::vector<bst_float> *out_preds,
unsigned) override { unsigned layer_begin, unsigned layer_end) override {
LinearCheckLayer(layer_begin, layer_end);
const int ngroup = model_.learner_model_param->num_output_group; const int ngroup = model_.learner_model_param->num_output_group;
for (int gid = 0; gid < ngroup; ++gid) { for (int gid = 0; gid < ngroup; ++gid) {
this->Pred(inst, dmlc::BeginPtr(*out_preds), gid, this->Pred(inst, dmlc::BeginPtr(*out_preds), gid,
@ -151,16 +156,15 @@ class GBLinear : public GradientBooster {
} }
} }
void PredictLeaf(DMatrix *, HostDeviceVector<bst_float> *, unsigned) override { void PredictLeaf(DMatrix *, HostDeviceVector<bst_float> *, unsigned, unsigned) override {
LOG(FATAL) << "gblinear does not support prediction of leaf index"; LOG(FATAL) << "gblinear does not support prediction of leaf index";
} }
void PredictContribution(DMatrix* p_fmat, void PredictContribution(DMatrix* p_fmat,
HostDeviceVector<bst_float>* out_contribs, HostDeviceVector<bst_float>* out_contribs,
unsigned ntree_limit, bool, int, unsigned) override { unsigned layer_begin, unsigned layer_end, bool, int, unsigned) override {
model_.LazyInitModel(); model_.LazyInitModel();
CHECK_EQ(ntree_limit, 0U) LinearCheckLayer(layer_begin, layer_end);
<< "GBLinear::PredictContribution: ntrees is only valid for gbtree predictor";
const auto& base_margin = p_fmat->Info().base_margin_.ConstHostVector(); const auto& base_margin = p_fmat->Info().base_margin_.ConstHostVector();
const int ngroup = model_.learner_model_param->num_output_group; const int ngroup = model_.learner_model_param->num_output_group;
const size_t ncolumns = model_.learner_model_param->num_feature + 1; const size_t ncolumns = model_.learner_model_param->num_feature + 1;
@ -197,7 +201,8 @@ class GBLinear : public GradientBooster {
void PredictInteractionContributions(DMatrix* p_fmat, void PredictInteractionContributions(DMatrix* p_fmat,
HostDeviceVector<bst_float>* out_contribs, HostDeviceVector<bst_float>* out_contribs,
unsigned, bool) override { unsigned layer_begin, unsigned layer_end, bool) override {
LinearCheckLayer(layer_begin, layer_end);
std::vector<bst_float>& contribs = out_contribs->HostVector(); std::vector<bst_float>& contribs = out_contribs->HostVector();
// linear models have no interaction effects // linear models have no interaction effects

View File

@ -414,7 +414,7 @@ void GBTree::Slice(int32_t layer_begin, int32_t layer_end, int32_t step,
auto layer_trees = this->LayerTrees(); auto layer_trees = this->LayerTrees();
layer_end = layer_end == 0 ? model_.trees.size() / layer_trees : layer_end; layer_end = layer_end == 0 ? model_.trees.size() / layer_trees : layer_end;
CHECK_GE(layer_end, layer_begin); CHECK_GT(layer_end, layer_begin);
CHECK_GE(step, 1); CHECK_GE(step, 1);
int32_t n_layers = (layer_end - layer_begin) / step; int32_t n_layers = (layer_end - layer_begin) / step;
std::vector<std::unique_ptr<RegTree>> &out_trees = out_model.trees; std::vector<std::unique_ptr<RegTree>> &out_trees = out_model.trees;
@ -438,10 +438,35 @@ void GBTree::Slice(int32_t layer_begin, int32_t layer_end, int32_t step,
void GBTree::PredictBatch(DMatrix* p_fmat, void GBTree::PredictBatch(DMatrix* p_fmat,
PredictionCacheEntry* out_preds, PredictionCacheEntry* out_preds,
bool, bool,
unsigned ntree_limit) { unsigned layer_begin,
unsigned layer_end) {
CHECK(configured_); CHECK(configured_);
if (layer_end == 0) {
layer_end = this->BoostedRounds();
}
if (layer_begin != 0 || layer_end < out_preds->version) {
// cache is dropped.
out_preds->version = 0;
}
bool reset = false;
if (layer_begin == 0) {
layer_begin = out_preds->version;
} else {
// When begin layer is not 0, the cache is not useful.
reset = true;
}
uint32_t tree_begin, tree_end;
std::tie(tree_begin, tree_end) =
detail::LayerToTree(model_, tparam_, layer_begin, layer_end);
GetPredictor(&out_preds->predictions, p_fmat) GetPredictor(&out_preds->predictions, p_fmat)
->PredictBatch(p_fmat, out_preds, model_, 0, ntree_limit); ->PredictBatch(p_fmat, out_preds, model_, tree_begin, tree_end);
if (reset) {
out_preds->version = 0;
} else {
uint32_t delta = layer_end - out_preds->version;
out_preds->Update(delta);
}
} }
std::unique_ptr<Predictor> const & std::unique_ptr<Predictor> const &
@ -603,13 +628,14 @@ class Dart : public GBTree {
void PredictBatch(DMatrix* p_fmat, void PredictBatch(DMatrix* p_fmat,
PredictionCacheEntry* p_out_preds, PredictionCacheEntry* p_out_preds,
bool training, bool training,
unsigned ntree_limit) override { unsigned layer_begin,
unsigned layer_end) override {
DropTrees(training); DropTrees(training);
int num_group = model_.learner_model_param->num_output_group; int num_group = model_.learner_model_param->num_output_group;
ntree_limit *= num_group; uint32_t tree_begin, tree_end;
if (ntree_limit == 0 || ntree_limit > model_.trees.size()) { std::tie(tree_begin, tree_end) =
ntree_limit = static_cast<unsigned>(model_.trees.size()); detail::LayerToTree(model_, tparam_, layer_begin, layer_end);
}
size_t n = num_group * p_fmat->Info().num_row_; size_t n = num_group * p_fmat->Info().num_row_;
const auto &base_margin = p_fmat->Info().base_margin_.ConstHostVector(); const auto &base_margin = p_fmat->Info().base_margin_.ConstHostVector();
auto& out_preds = p_out_preds->predictions.HostVector(); auto& out_preds = p_out_preds->predictions.HostVector();
@ -623,26 +649,24 @@ class Dart : public GBTree {
} }
const int nthread = omp_get_max_threads(); const int nthread = omp_get_max_threads();
InitThreadTemp(nthread); InitThreadTemp(nthread);
PredLoopSpecalize(p_fmat, &out_preds, num_group, 0, ntree_limit); PredLoopSpecalize(p_fmat, &out_preds, num_group, tree_begin, tree_end);
} }
void PredictInstance(const SparsePage::Inst &inst, void PredictInstance(const SparsePage::Inst &inst,
std::vector<bst_float> *out_preds, std::vector<bst_float> *out_preds,
unsigned ntree_limit) override { unsigned layer_begin, unsigned layer_end) override {
DropTrees(false); DropTrees(false);
if (thread_temp_.size() == 0) { if (thread_temp_.size() == 0) {
thread_temp_.resize(1, RegTree::FVec()); thread_temp_.resize(1, RegTree::FVec());
thread_temp_[0].Init(model_.learner_model_param->num_feature); thread_temp_[0].Init(model_.learner_model_param->num_feature);
} }
out_preds->resize(model_.learner_model_param->num_output_group); out_preds->resize(model_.learner_model_param->num_output_group);
ntree_limit *= model_.learner_model_param->num_output_group; uint32_t tree_begin, tree_end;
if (ntree_limit == 0 || ntree_limit > model_.trees.size()) { std::tie(tree_begin, tree_end) = detail::LayerToTree(model_, tparam_, layer_begin, layer_end);
ntree_limit = static_cast<unsigned>(model_.trees.size());
}
// loop over output groups // loop over output groups
for (uint32_t gid = 0; gid < model_.learner_model_param->num_output_group; ++gid) { for (uint32_t gid = 0; gid < model_.learner_model_param->num_output_group; ++gid) {
(*out_preds)[gid] = (*out_preds)[gid] =
PredValue(inst, gid, &thread_temp_[0], 0, ntree_limit) + PredValue(inst, gid, &thread_temp_[0], 0, tree_end) +
model_.learner_model_param->base_score; model_.learner_model_param->base_score;
} }
} }
@ -653,22 +677,25 @@ class Dart : public GBTree {
void PredictContribution(DMatrix* p_fmat, void PredictContribution(DMatrix* p_fmat,
HostDeviceVector<bst_float>* out_contribs, HostDeviceVector<bst_float>* out_contribs,
unsigned ntree_limit, bool approximate, int, unsigned layer_begin, unsigned layer_end, bool approximate, int,
unsigned) override { unsigned) override {
CHECK(configured_); CHECK(configured_);
uint32_t tree_begin, tree_end;
std::tie(tree_begin, tree_end) = detail::LayerToTree(model_, tparam_, layer_begin, layer_end);
cpu_predictor_->PredictContribution(p_fmat, out_contribs, model_, cpu_predictor_->PredictContribution(p_fmat, out_contribs, model_,
ntree_limit, &weight_drop_, approximate); tree_end, &weight_drop_, approximate);
} }
void PredictInteractionContributions(DMatrix* p_fmat, void PredictInteractionContributions(
HostDeviceVector<bst_float>* out_contribs, DMatrix *p_fmat, HostDeviceVector<bst_float> *out_contribs,
unsigned ntree_limit, bool approximate) override { unsigned layer_begin, unsigned layer_end, bool approximate) override {
CHECK(configured_); CHECK(configured_);
uint32_t tree_begin, tree_end;
std::tie(tree_begin, tree_end) = detail::LayerToTree(model_, tparam_, layer_begin, layer_end);
cpu_predictor_->PredictInteractionContributions(p_fmat, out_contribs, model_, cpu_predictor_->PredictInteractionContributions(p_fmat, out_contribs, model_,
ntree_limit, &weight_drop_, approximate); tree_end, &weight_drop_, approximate);
} }
protected: protected:
inline void PredLoopSpecalize( inline void PredLoopSpecalize(
DMatrix* p_fmat, DMatrix* p_fmat,

View File

@ -164,7 +164,9 @@ inline std::pair<uint32_t, uint32_t> LayerToTree(gbm::GBTreeModel const &model,
if (tree_end == 0) { if (tree_end == 0) {
tree_end = static_cast<uint32_t>(model.trees.size()); tree_end = static_cast<uint32_t>(model.trees.size());
} }
CHECK_LT(tree_begin, tree_end); if (model.trees.size() != 0) {
CHECK_LE(tree_begin, tree_end);
}
return {tree_begin, tree_end}; return {tree_begin, tree_end};
} }
@ -260,10 +262,8 @@ class GBTree : public GradientBooster {
return model_.trees.size() / this->LayerTrees(); return model_.trees.size() / this->LayerTrees();
} }
void PredictBatch(DMatrix* p_fmat, void PredictBatch(DMatrix *p_fmat, PredictionCacheEntry *out_preds,
PredictionCacheEntry* out_preds, bool training, unsigned layer_begin, unsigned layer_end) override;
bool training,
unsigned ntree_limit) override;
void InplacePredict(dmlc::any const &x, std::shared_ptr<DMatrix> p_m, void InplacePredict(dmlc::any const &x, std::shared_ptr<DMatrix> p_m,
float missing, PredictionCacheEntry *out_preds, float missing, PredictionCacheEntry *out_preds,
@ -297,33 +297,49 @@ class GBTree : public GradientBooster {
void PredictInstance(const SparsePage::Inst& inst, void PredictInstance(const SparsePage::Inst& inst,
std::vector<bst_float>* out_preds, std::vector<bst_float>* out_preds,
unsigned ntree_limit) override { uint32_t layer_begin, uint32_t layer_end) override {
CHECK(configured_); CHECK(configured_);
uint32_t tree_begin, tree_end;
std::tie(tree_begin, tree_end) = detail::LayerToTree(model_, tparam_, layer_begin, layer_end);
cpu_predictor_->PredictInstance(inst, out_preds, model_, cpu_predictor_->PredictInstance(inst, out_preds, model_,
ntree_limit); tree_end);
} }
void PredictLeaf(DMatrix* p_fmat, void PredictLeaf(DMatrix* p_fmat,
HostDeviceVector<bst_float>* out_preds, HostDeviceVector<bst_float>* out_preds,
unsigned ntree_limit) override { uint32_t layer_begin, uint32_t layer_end) override {
this->GetPredictor()->PredictLeaf(p_fmat, out_preds, model_, ntree_limit); uint32_t tree_begin, tree_end;
std::tie(tree_begin, tree_end) = detail::LayerToTree(model_, tparam_, layer_begin, layer_end);
CHECK_EQ(tree_begin, 0) << "Predict leaf supports only iteration end: (0, "
"n_iteration), use model slicing instead.";
this->GetPredictor()->PredictLeaf(p_fmat, out_preds, model_, tree_end);
} }
void PredictContribution(DMatrix* p_fmat, void PredictContribution(DMatrix* p_fmat,
HostDeviceVector<bst_float>* out_contribs, HostDeviceVector<bst_float>* out_contribs,
unsigned ntree_limit, bool approximate, uint32_t layer_begin, uint32_t layer_end, bool approximate,
int, unsigned) override { int, unsigned) override {
CHECK(configured_); CHECK(configured_);
uint32_t tree_begin, tree_end;
std::tie(tree_begin, tree_end) = detail::LayerToTree(model_, tparam_, layer_begin, layer_end);
CHECK_EQ(tree_begin, 0)
<< "Predict contribution supports only iteration end: (0, "
"n_iteration), using model slicing instead.";
this->GetPredictor()->PredictContribution( this->GetPredictor()->PredictContribution(
p_fmat, out_contribs, model_, ntree_limit, nullptr, approximate); p_fmat, out_contribs, model_, tree_end, nullptr, approximate);
} }
void PredictInteractionContributions(DMatrix* p_fmat, void PredictInteractionContributions(
HostDeviceVector<bst_float>* out_contribs, DMatrix *p_fmat, HostDeviceVector<bst_float> *out_contribs,
unsigned ntree_limit, bool approximate) override { uint32_t layer_begin, uint32_t layer_end, bool approximate) override {
CHECK(configured_); CHECK(configured_);
this->GetPredictor()->PredictInteractionContributions(p_fmat, out_contribs, model_, uint32_t tree_begin, tree_end;
ntree_limit, nullptr, approximate); std::tie(tree_begin, tree_end) = detail::LayerToTree(model_, tparam_, layer_begin, layer_end);
CHECK_EQ(tree_begin, 0)
<< "Predict interaction contribution supports only iteration end: (0, "
"n_iteration), using model slicing instead.";
this->GetPredictor()->PredictInteractionContributions(
p_fmat, out_contribs, model_, tree_end, nullptr, approximate);
} }
std::vector<std::string> DumpModel(const FeatureMap& fmap, std::vector<std::string> DumpModel(const FeatureMap& fmap,

View File

@ -22,6 +22,7 @@
#include "dmlc/any.h" #include "dmlc/any.h"
#include "xgboost/base.h" #include "xgboost/base.h"
#include "xgboost/c_api.h"
#include "xgboost/data.h" #include "xgboost/data.h"
#include "xgboost/model.h" #include "xgboost/model.h"
#include "xgboost/predictor.h" #include "xgboost/predictor.h"
@ -996,7 +997,7 @@ class LearnerImpl : public LearnerIO {
auto& predt = local_cache->Cache(train, generic_parameters_.gpu_id); auto& predt = local_cache->Cache(train, generic_parameters_.gpu_id);
monitor_.Start("PredictRaw"); monitor_.Start("PredictRaw");
this->PredictRaw(train.get(), &predt, true); this->PredictRaw(train.get(), &predt, true, 0, 0);
TrainingObserver::Instance().Observe(predt.predictions, "Predictions"); TrainingObserver::Instance().Observe(predt.predictions, "Predictions");
monitor_.Stop("PredictRaw"); monitor_.Stop("PredictRaw");
@ -1057,7 +1058,7 @@ class LearnerImpl : public LearnerIO {
std::shared_ptr<DMatrix> m = data_sets[i]; std::shared_ptr<DMatrix> m = data_sets[i];
auto &predt = local_cache->Cache(m, generic_parameters_.gpu_id); auto &predt = local_cache->Cache(m, generic_parameters_.gpu_id);
this->ValidateDMatrix(m.get(), false); this->ValidateDMatrix(m.get(), false);
this->PredictRaw(m.get(), &predt, false); this->PredictRaw(m.get(), &predt, false, 0, 0);
auto &out = output_predictions_.Cache(m, generic_parameters_.gpu_id).predictions; auto &out = output_predictions_.Cache(m, generic_parameters_.gpu_id).predictions;
out.Resize(predt.predictions.Size()); out.Resize(predt.predictions.Size());
@ -1075,8 +1076,8 @@ class LearnerImpl : public LearnerIO {
} }
void Predict(std::shared_ptr<DMatrix> data, bool output_margin, void Predict(std::shared_ptr<DMatrix> data, bool output_margin,
HostDeviceVector<bst_float>* out_preds, unsigned ntree_limit, HostDeviceVector<bst_float> *out_preds, unsigned layer_begin,
bool training, unsigned layer_end, bool training,
bool pred_leaf, bool pred_contribs, bool approx_contribs, bool pred_leaf, bool pred_contribs, bool approx_contribs,
bool pred_interactions) override { bool pred_interactions) override {
int multiple_predictions = static_cast<int>(pred_leaf) + int multiple_predictions = static_cast<int>(pred_leaf) +
@ -1085,16 +1086,16 @@ class LearnerImpl : public LearnerIO {
this->Configure(); this->Configure();
CHECK_LE(multiple_predictions, 1) << "Perform one kind of prediction at a time."; CHECK_LE(multiple_predictions, 1) << "Perform one kind of prediction at a time.";
if (pred_contribs) { if (pred_contribs) {
gbm_->PredictContribution(data.get(), out_preds, ntree_limit, approx_contribs); gbm_->PredictContribution(data.get(), out_preds, layer_begin, layer_end, approx_contribs);
} else if (pred_interactions) { } else if (pred_interactions) {
gbm_->PredictInteractionContributions(data.get(), out_preds, ntree_limit, gbm_->PredictInteractionContributions(data.get(), out_preds, layer_begin, layer_end,
approx_contribs); approx_contribs);
} else if (pred_leaf) { } else if (pred_leaf) {
gbm_->PredictLeaf(data.get(), out_preds, ntree_limit); gbm_->PredictLeaf(data.get(), out_preds, layer_begin, layer_end);
} else { } else {
auto local_cache = this->GetPredictionCache(); auto local_cache = this->GetPredictionCache();
auto& prediction = local_cache->Cache(data, generic_parameters_.gpu_id); auto& prediction = local_cache->Cache(data, generic_parameters_.gpu_id);
this->PredictRaw(data.get(), &prediction, training, ntree_limit); this->PredictRaw(data.get(), &prediction, training, layer_begin, layer_end);
// Copy the prediction cache to output prediction. out_preds comes from C API // Copy the prediction cache to output prediction. out_preds comes from C API
out_preds->SetDevice(generic_parameters_.gpu_id); out_preds->SetDevice(generic_parameters_.gpu_id);
out_preds->Resize(prediction.predictions.Size()); out_preds->Resize(prediction.predictions.Size());
@ -1151,12 +1152,11 @@ class LearnerImpl : public LearnerIO {
* predictor, when it equals 0, this means we are using all the trees * predictor, when it equals 0, this means we are using all the trees
* \param training allow dropout when the DART booster is being used * \param training allow dropout when the DART booster is being used
*/ */
void PredictRaw(DMatrix* data, PredictionCacheEntry* out_preds, void PredictRaw(DMatrix *data, PredictionCacheEntry *out_preds, bool training,
bool training, unsigned layer_begin, unsigned layer_end) const {
unsigned ntree_limit = 0) const {
CHECK(gbm_ != nullptr) << "Predict must happen after Load or configuration"; CHECK(gbm_ != nullptr) << "Predict must happen after Load or configuration";
this->ValidateDMatrix(data, false); this->ValidateDMatrix(data, false);
gbm_->PredictBatch(data, out_preds, training, ntree_limit); gbm_->PredictBatch(data, out_preds, training, layer_begin, layer_end);
} }
void ValidateDMatrix(DMatrix* p_fmat, bool is_training) const { void ValidateDMatrix(DMatrix* p_fmat, bool is_training) const {

View File

@ -234,56 +234,28 @@ class CPUPredictor : public Predictor {
public: public:
explicit CPUPredictor(GenericParameter const* generic_param) : explicit CPUPredictor(GenericParameter const* generic_param) :
Predictor::Predictor{generic_param} {} Predictor::Predictor{generic_param} {}
// ntree_limit is a very problematic parameter, as it's ambiguous in the context of void PredictBatch(DMatrix *dmat, PredictionCacheEntry *predts,
// multi-output and forest. Same problem exists for tree_begin const gbm::GBTreeModel &model, uint32_t tree_begin,
void PredictBatch(DMatrix* dmat, PredictionCacheEntry* predts, uint32_t tree_end = 0) const override {
const gbm::GBTreeModel& model, int tree_begin,
uint32_t const ntree_limit = 0) const override {
// tree_begin is not used, right now we just enforce it to be 0.
CHECK_EQ(tree_begin, 0);
auto* out_preds = &predts->predictions; auto* out_preds = &predts->predictions;
CHECK_GE(predts->version, tree_begin);
if (out_preds->Size() == 0 && dmat->Info().num_row_ != 0) { if (out_preds->Size() == 0 && dmat->Info().num_row_ != 0) {
CHECK_EQ(predts->version, 0); CHECK_EQ(predts->version, 0);
} }
// This is actually already handled in gbm, but large amount of tests rely on the
// behaviour.
if (tree_end == 0) {
tree_end = model.trees.size();
}
if (predts->version == 0) { if (predts->version == 0) {
// out_preds->Size() can be non-zero as it's initialized here before any tree is // out_preds->Size() can be non-zero as it's initialized here before any tree is
// built at the 0^th iterator. // built at the 0^th iterator.
this->InitOutPredictions(dmat->Info(), out_preds, model); this->InitOutPredictions(dmat->Info(), out_preds, model);
} }
if (tree_end - tree_begin == 0) {
uint32_t const output_groups = model.learner_model_param->num_output_group; return;
CHECK_NE(output_groups, 0);
// Right now we just assume ntree_limit provided by users means number of tree layers
// in the context of multi-output model
uint32_t real_ntree_limit = ntree_limit * output_groups;
if (real_ntree_limit == 0 || real_ntree_limit > model.trees.size()) {
real_ntree_limit = static_cast<uint32_t>(model.trees.size());
} }
this->PredictDMatrix(dmat, &out_preds->HostVector(), model, tree_begin,
uint32_t const end_version = (tree_begin + real_ntree_limit) / output_groups; tree_end);
// When users have provided ntree_limit, end_version can be lesser, cache is violated
if (predts->version > end_version) {
CHECK_NE(ntree_limit, 0);
this->InitOutPredictions(dmat->Info(), out_preds, model);
predts->version = 0;
}
uint32_t const beg_version = predts->version;
CHECK_LE(beg_version, end_version);
if (beg_version < end_version) {
this->PredictDMatrix(dmat, &out_preds->HostVector(), model,
beg_version * output_groups,
end_version * output_groups);
}
// delta means {size of forest} * {number of newly accumulated layers}
uint32_t delta = end_version - beg_version;
CHECK_LE(delta, model.trees.size());
predts->Update(delta);
CHECK(out_preds->Size() == output_groups * dmat->Info().num_row_ ||
out_preds->Size() == dmat->Info().num_row_);
} }
template <typename Adapter> template <typename Adapter>
@ -362,7 +334,6 @@ class CPUPredictor : public Predictor {
InitThreadTemp(nthread, model.learner_model_param->num_feature, &feat_vecs); InitThreadTemp(nthread, model.learner_model_param->num_feature, &feat_vecs);
const MetaInfo& info = p_fmat->Info(); const MetaInfo& info = p_fmat->Info();
// number of valid trees // number of valid trees
ntree_limit *= model.learner_model_param->num_output_group;
if (ntree_limit == 0 || ntree_limit > model.trees.size()) { if (ntree_limit == 0 || ntree_limit > model.trees.size()) {
ntree_limit = static_cast<unsigned>(model.trees.size()); ntree_limit = static_cast<unsigned>(model.trees.size());
} }
@ -398,7 +369,6 @@ class CPUPredictor : public Predictor {
InitThreadTemp(nthread, model.learner_model_param->num_feature, &feat_vecs); InitThreadTemp(nthread, model.learner_model_param->num_feature, &feat_vecs);
const MetaInfo& info = p_fmat->Info(); const MetaInfo& info = p_fmat->Info();
// number of valid trees // number of valid trees
ntree_limit *= model.learner_model_param->num_output_group;
if (ntree_limit == 0 || ntree_limit > model.trees.size()) { if (ntree_limit == 0 || ntree_limit > model.trees.size()) {
ntree_limit = static_cast<unsigned>(model.trees.size()); ntree_limit = static_cast<unsigned>(model.trees.size());
} }

View File

@ -536,6 +536,7 @@ class GPUPredictor : public xgboost::Predictor {
const uint32_t BLOCK_THREADS = 256; const uint32_t BLOCK_THREADS = 256;
size_t num_rows = batch.n_rows; size_t num_rows = batch.n_rows;
auto GRID_SIZE = static_cast<uint32_t>(common::DivRoundUp(num_rows, BLOCK_THREADS)); auto GRID_SIZE = static_cast<uint32_t>(common::DivRoundUp(num_rows, BLOCK_THREADS));
DeviceModel d_model;
bool use_shared = false; bool use_shared = false;
size_t entry_start = 0; size_t entry_start = 0;
@ -593,54 +594,27 @@ class GPUPredictor : public xgboost::Predictor {
} }
void PredictBatch(DMatrix* dmat, PredictionCacheEntry* predts, void PredictBatch(DMatrix* dmat, PredictionCacheEntry* predts,
const gbm::GBTreeModel& model, int tree_begin, const gbm::GBTreeModel& model, uint32_t tree_begin,
unsigned ntree_limit = 0) const override { uint32_t tree_end = 0) const override {
// This function is duplicated with CPU predictor PredictBatch, see comments in there.
// FIXME(trivialfis): Remove the duplication.
int device = generic_param_->gpu_id; int device = generic_param_->gpu_id;
CHECK_GE(device, 0) << "Set `gpu_id' to positive value for processing GPU data."; CHECK_GE(device, 0) << "Set `gpu_id' to positive value for processing GPU data.";
ConfigureDevice(device);
CHECK_EQ(tree_begin, 0);
auto* out_preds = &predts->predictions; auto* out_preds = &predts->predictions;
CHECK_GE(predts->version, tree_begin);
if (out_preds->Size() == 0 && dmat->Info().num_row_ != 0) { if (out_preds->Size() == 0 && dmat->Info().num_row_ != 0) {
CHECK_EQ(predts->version, 0); CHECK_EQ(predts->version, 0);
} }
if (tree_end == 0) {
tree_end = model.trees.size();
}
if (predts->version == 0) { if (predts->version == 0) {
// out_preds->Size() can be non-zero as it's initialized here before any tree is
// built at the 0^th iterator.
this->InitOutPredictions(dmat->Info(), out_preds, model); this->InitOutPredictions(dmat->Info(), out_preds, model);
} }
if (tree_end - tree_begin == 0) {
uint32_t const output_groups = model.learner_model_param->num_output_group; return;
CHECK_NE(output_groups, 0);
uint32_t real_ntree_limit = ntree_limit * output_groups;
if (real_ntree_limit == 0 || real_ntree_limit > model.trees.size()) {
real_ntree_limit = static_cast<uint32_t>(model.trees.size());
} }
this->DevicePredictInternal(dmat, out_preds, model, tree_begin, tree_end);
uint32_t const end_version = (tree_begin + real_ntree_limit) / output_groups;
if (predts->version > end_version) {
CHECK_NE(ntree_limit, 0);
this->InitOutPredictions(dmat->Info(), out_preds, model);
predts->version = 0;
}
uint32_t const beg_version = predts->version;
CHECK_LE(beg_version, end_version);
if (beg_version < end_version) {
this->DevicePredictInternal(dmat, out_preds, model,
beg_version * output_groups,
end_version * output_groups);
}
uint32_t delta = end_version - beg_version;
CHECK_LE(delta, model.trees.size());
predts->Update(delta);
CHECK(out_preds->Size() == output_groups * dmat->Info().num_row_ ||
out_preds->Size() == dmat->Info().num_row_);
} }
template <typename Adapter, typename Loader> template <typename Adapter, typename Loader>
@ -648,15 +622,12 @@ class GPUPredictor : public xgboost::Predictor {
const gbm::GBTreeModel &model, float, const gbm::GBTreeModel &model, float,
PredictionCacheEntry *out_preds, PredictionCacheEntry *out_preds,
uint32_t tree_begin, uint32_t tree_end) const { uint32_t tree_begin, uint32_t tree_end) const {
auto max_shared_memory_bytes = dh::MaxSharedMemory(this->generic_param_->gpu_id);
uint32_t const output_groups = model.learner_model_param->num_output_group; uint32_t const output_groups = model.learner_model_param->num_output_group;
DeviceModel d_model;
d_model.Init(model, tree_begin, tree_end, this->generic_param_->gpu_id);
auto m = dmlc::get<std::shared_ptr<Adapter>>(x); auto m = dmlc::get<std::shared_ptr<Adapter>>(x);
CHECK_EQ(m->NumColumns(), model.learner_model_param->num_feature) CHECK_EQ(m->NumColumns(), model.learner_model_param->num_feature)
<< "Number of columns in data must equal to trained model."; << "Number of columns in data must equal to trained model.";
CHECK_EQ(this->generic_param_->gpu_id, m->DeviceIdx()) CHECK_EQ(dh::CurrentDevice(), m->DeviceIdx())
<< "XGBoost is running on device: " << this->generic_param_->gpu_id << ", " << "XGBoost is running on device: " << this->generic_param_->gpu_id << ", "
<< "but data is on: " << m->DeviceIdx(); << "but data is on: " << m->DeviceIdx();
if (p_m) { if (p_m) {
@ -667,12 +638,17 @@ class GPUPredictor : public xgboost::Predictor {
info.num_row_ = m->NumRows(); info.num_row_ = m->NumRows();
this->InitOutPredictions(info, &(out_preds->predictions), model); this->InitOutPredictions(info, &(out_preds->predictions), model);
} }
out_preds->predictions.SetDevice(m->DeviceIdx());
const uint32_t BLOCK_THREADS = 128; const uint32_t BLOCK_THREADS = 128;
auto GRID_SIZE = static_cast<uint32_t>(common::DivRoundUp(m->NumRows(), BLOCK_THREADS)); auto GRID_SIZE = static_cast<uint32_t>(common::DivRoundUp(m->NumRows(), BLOCK_THREADS));
auto max_shared_memory_bytes = dh::MaxSharedMemory(m->DeviceIdx());
size_t shared_memory_bytes = size_t shared_memory_bytes =
SharedMemoryBytes<BLOCK_THREADS>(m->NumColumns(), max_shared_memory_bytes); SharedMemoryBytes<BLOCK_THREADS>(m->NumColumns(), max_shared_memory_bytes);
DeviceModel d_model;
d_model.Init(model, tree_begin, tree_end, m->DeviceIdx());
bool use_shared = shared_memory_bytes != 0; bool use_shared = shared_memory_bytes != 0;
size_t entry_start = 0; size_t entry_start = 0;
@ -707,20 +683,17 @@ class GPUPredictor : public xgboost::Predictor {
void PredictContribution(DMatrix* p_fmat, void PredictContribution(DMatrix* p_fmat,
HostDeviceVector<bst_float>* out_contribs, HostDeviceVector<bst_float>* out_contribs,
const gbm::GBTreeModel& model, unsigned ntree_limit, const gbm::GBTreeModel& model, unsigned tree_end,
std::vector<bst_float>*, std::vector<bst_float>*,
bool approximate, int, bool approximate, int,
unsigned) const override { unsigned) const override {
if (approximate) { if (approximate) {
LOG(FATAL) << "Approximated contribution is not implemented in GPU Predictor."; LOG(FATAL) << "Approximated contribution is not implemented in GPU Predictor.";
} }
dh::safe_cuda(cudaSetDevice(generic_param_->gpu_id)); dh::safe_cuda(cudaSetDevice(generic_param_->gpu_id));
out_contribs->SetDevice(generic_param_->gpu_id); out_contribs->SetDevice(generic_param_->gpu_id);
uint32_t real_ntree_limit = if (tree_end == 0 || tree_end > model.trees.size()) {
ntree_limit * model.learner_model_param->num_output_group; tree_end = static_cast<uint32_t>(model.trees.size());
if (real_ntree_limit == 0 || real_ntree_limit > model.trees.size()) {
real_ntree_limit = static_cast<uint32_t>(model.trees.size());
} }
const int ngroup = model.learner_model_param->num_output_group; const int ngroup = model.learner_model_param->num_output_group;
@ -734,8 +707,7 @@ class GPUPredictor : public xgboost::Predictor {
auto phis = out_contribs->DeviceSpan(); auto phis = out_contribs->DeviceSpan();
dh::device_vector<gpu_treeshap::PathElement> device_paths; dh::device_vector<gpu_treeshap::PathElement> device_paths;
ExtractPaths(&device_paths, model, real_ntree_limit, ExtractPaths(&device_paths, model, tree_end, generic_param_->gpu_id);
generic_param_->gpu_id);
for (auto& batch : p_fmat->GetBatches<SparsePage>()) { for (auto& batch : p_fmat->GetBatches<SparsePage>()) {
batch.data.SetDevice(generic_param_->gpu_id); batch.data.SetDevice(generic_param_->gpu_id);
batch.offset.SetDevice(generic_param_->gpu_id); batch.offset.SetDevice(generic_param_->gpu_id);
@ -761,20 +733,17 @@ class GPUPredictor : public xgboost::Predictor {
void PredictInteractionContributions(DMatrix* p_fmat, void PredictInteractionContributions(DMatrix* p_fmat,
HostDeviceVector<bst_float>* out_contribs, HostDeviceVector<bst_float>* out_contribs,
const gbm::GBTreeModel& model, const gbm::GBTreeModel& model,
unsigned ntree_limit, unsigned tree_end,
std::vector<bst_float>*, std::vector<bst_float>*,
bool approximate) const override { bool approximate) const override {
if (approximate) { if (approximate) {
LOG(FATAL) << "[Internal error]: " << __func__ LOG(FATAL) << "[Internal error]: " << __func__
<< " approximate is not implemented in GPU Predictor."; << " approximate is not implemented in GPU Predictor.";
} }
dh::safe_cuda(cudaSetDevice(generic_param_->gpu_id)); dh::safe_cuda(cudaSetDevice(generic_param_->gpu_id));
out_contribs->SetDevice(generic_param_->gpu_id); out_contribs->SetDevice(generic_param_->gpu_id);
uint32_t real_ntree_limit = if (tree_end == 0 || tree_end > model.trees.size()) {
ntree_limit * model.learner_model_param->num_output_group; tree_end = static_cast<uint32_t>(model.trees.size());
if (real_ntree_limit == 0 || real_ntree_limit > model.trees.size()) {
real_ntree_limit = static_cast<uint32_t>(model.trees.size());
} }
const int ngroup = model.learner_model_param->num_output_group; const int ngroup = model.learner_model_param->num_output_group;
@ -789,8 +758,7 @@ class GPUPredictor : public xgboost::Predictor {
auto phis = out_contribs->DeviceSpan(); auto phis = out_contribs->DeviceSpan();
dh::device_vector<gpu_treeshap::PathElement> device_paths; dh::device_vector<gpu_treeshap::PathElement> device_paths;
ExtractPaths(&device_paths, model, real_ntree_limit, ExtractPaths(&device_paths, model, tree_end, generic_param_->gpu_id);
generic_param_->gpu_id);
for (auto& batch : p_fmat->GetBatches<SparsePage>()) { for (auto& batch : p_fmat->GetBatches<SparsePage>()) {
batch.data.SetDevice(generic_param_->gpu_id); batch.data.SetDevice(generic_param_->gpu_id);
batch.offset.SetDevice(generic_param_->gpu_id); batch.offset.SetDevice(generic_param_->gpu_id);
@ -841,29 +809,28 @@ class GPUPredictor : public xgboost::Predictor {
<< " is not implemented in GPU Predictor."; << " is not implemented in GPU Predictor.";
} }
void PredictLeaf(DMatrix* p_fmat, HostDeviceVector<bst_float>* predictions, void PredictLeaf(DMatrix *p_fmat, HostDeviceVector<bst_float> *predictions,
const gbm::GBTreeModel& model, const gbm::GBTreeModel &model,
unsigned ntree_limit) const override { unsigned tree_end) const override {
dh::safe_cuda(cudaSetDevice(generic_param_->gpu_id)); dh::safe_cuda(cudaSetDevice(generic_param_->gpu_id));
auto max_shared_memory_bytes = ConfigureDevice(generic_param_->gpu_id); auto max_shared_memory_bytes = ConfigureDevice(generic_param_->gpu_id);
const MetaInfo& info = p_fmat->Info(); const MetaInfo& info = p_fmat->Info();
constexpr uint32_t kBlockThreads = 128; constexpr uint32_t kBlockThreads = 128;
size_t shared_memory_bytes = size_t shared_memory_bytes = SharedMemoryBytes<kBlockThreads>(
SharedMemoryBytes<kBlockThreads>(info.num_col_, max_shared_memory_bytes); info.num_col_, max_shared_memory_bytes);
bool use_shared = shared_memory_bytes != 0; bool use_shared = shared_memory_bytes != 0;
bst_feature_t num_features = info.num_col_; bst_feature_t num_features = info.num_col_;
bst_row_t num_rows = info.num_row_; bst_row_t num_rows = info.num_row_;
size_t entry_start = 0; size_t entry_start = 0;
uint32_t real_ntree_limit = ntree_limit * model.learner_model_param->num_output_group; if (tree_end == 0 || tree_end > model.trees.size()) {
if (real_ntree_limit == 0 || real_ntree_limit > model.trees.size()) { tree_end = static_cast<uint32_t>(model.trees.size());
real_ntree_limit = static_cast<uint32_t>(model.trees.size());
} }
predictions->SetDevice(generic_param_->gpu_id); predictions->SetDevice(generic_param_->gpu_id);
predictions->Resize(num_rows * real_ntree_limit); predictions->Resize(num_rows * tree_end);
DeviceModel d_model; DeviceModel d_model;
d_model.Init(model, 0, real_ntree_limit, this->generic_param_->gpu_id); d_model.Init(model, 0, tree_end, this->generic_param_->gpu_id);
if (p_fmat->PageExists<SparsePage>()) { if (p_fmat->PageExists<SparsePage>()) {
for (auto const& batch : p_fmat->GetBatches<SparsePage>()) { for (auto const& batch : p_fmat->GetBatches<SparsePage>()) {

View File

@ -34,6 +34,7 @@ dependencies:
- llvmlite - llvmlite
- pip: - pip:
- shap - shap
- ipython # required by shap at import time.
- guzzle_sphinx_theme - guzzle_sphinx_theme
- datatable - datatable
- modin[all] - modin[all]

View File

@ -51,6 +51,53 @@ TEST(GBTree, SelectTreeMethod) {
#endif // XGBOOST_USE_CUDA #endif // XGBOOST_USE_CUDA
} }
TEST(GBTree, PredictionCache) {
size_t constexpr kRows = 100, kCols = 10;
GenericParameter generic_param;
generic_param.UpdateAllowUnknown(Args{});
LearnerModelParam mparam;
mparam.base_score = 0.5;
mparam.num_feature = kCols;
mparam.num_output_group = 1;
std::unique_ptr<GradientBooster> p_gbm {
GradientBooster::Create("gbtree", &generic_param, &mparam)};
auto& gbtree = dynamic_cast<gbm::GBTree&> (*p_gbm);
gbtree.Configure({{"tree_method", "hist"}});
auto p_m = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix();
auto gpair = GenerateRandomGradients(kRows);
PredictionCacheEntry out_predictions;
gbtree.DoBoost(p_m.get(), &gpair, &out_predictions);
gbtree.PredictBatch(p_m.get(), &out_predictions, false, 0, 0);
ASSERT_EQ(1, out_predictions.version);
std::vector<float> first_iter = out_predictions.predictions.HostVector();
// Add 1 more boosted round
gbtree.DoBoost(p_m.get(), &gpair, &out_predictions);
gbtree.PredictBatch(p_m.get(), &out_predictions, false, 0, 0);
ASSERT_EQ(2, out_predictions.version);
// Update the cache for all rounds
out_predictions.version = 0;
gbtree.PredictBatch(p_m.get(), &out_predictions, false, 0, 0);
ASSERT_EQ(2, out_predictions.version);
gbtree.DoBoost(p_m.get(), &gpair, &out_predictions);
// drop the cache.
gbtree.PredictBatch(p_m.get(), &out_predictions, false, 1, 2);
ASSERT_EQ(0, out_predictions.version);
// half open set [1, 3)
gbtree.PredictBatch(p_m.get(), &out_predictions, false, 1, 3);
ASSERT_EQ(0, out_predictions.version);
// iteration end
gbtree.PredictBatch(p_m.get(), &out_predictions, false, 0, 2);
ASSERT_EQ(2, out_predictions.version);
// restart the cache when end iteration is smaller than cache version
gbtree.PredictBatch(p_m.get(), &out_predictions, false, 0, 1);
ASSERT_EQ(1, out_predictions.version);
ASSERT_EQ(out_predictions.predictions.HostVector(), first_iter);
}
TEST(GBTree, WrongUpdater) { TEST(GBTree, WrongUpdater) {
size_t constexpr kRows = 17; size_t constexpr kRows = 17;
size_t constexpr kCols = 15; size_t constexpr kCols = 15;

View File

@ -32,7 +32,7 @@ TEST(CpuPredictor, Basic) {
// Test predict batch // Test predict batch
PredictionCacheEntry out_predictions; PredictionCacheEntry out_predictions;
cpu_predictor->PredictBatch(dmat.get(), &out_predictions, model, 0); cpu_predictor->PredictBatch(dmat.get(), &out_predictions, model, 0);
ASSERT_EQ(model.trees.size(), out_predictions.version);
std::vector<float>& out_predictions_h = out_predictions.predictions.HostVector(); std::vector<float>& out_predictions_h = out_predictions.predictions.HostVector();
for (size_t i = 0; i < out_predictions.predictions.Size(); i++) { for (size_t i = 0; i < out_predictions.predictions.Size(); i++) {
ASSERT_EQ(out_predictions_h[i], 1.5); ASSERT_EQ(out_predictions_h[i], 1.5);
@ -215,7 +215,7 @@ TEST(CpuPredictor, UpdatePredictionCache) {
PredictionCacheEntry out_predictions; PredictionCacheEntry out_predictions;
// perform fair prediction on the same input data, should be equal to cached result // perform fair prediction on the same input data, should be equal to cached result
gbm->PredictBatch(dmat.get(), &out_predictions, false, 0); gbm->PredictBatch(dmat.get(), &out_predictions, false, 0, 0);
std::vector<float> &out_predictions_h = out_predictions.predictions.HostVector(); std::vector<float> &out_predictions_h = out_predictions.predictions.HostVector();
std::vector<float> &predtion_cache_from_train = predtion_cache.predictions.HostVector(); std::vector<float> &predtion_cache_from_train = predtion_cache.predictions.HostVector();

View File

@ -45,7 +45,6 @@ TEST(GPUPredictor, Basic) {
PredictionCacheEntry cpu_out_predictions; PredictionCacheEntry cpu_out_predictions;
gpu_predictor->PredictBatch(dmat.get(), &gpu_out_predictions, model, 0); gpu_predictor->PredictBatch(dmat.get(), &gpu_out_predictions, model, 0);
ASSERT_EQ(model.trees.size(), gpu_out_predictions.version);
cpu_predictor->PredictBatch(dmat.get(), &cpu_out_predictions, model, 0); cpu_predictor->PredictBatch(dmat.get(), &cpu_out_predictions, model, 0);
std::vector<float>& gpu_out_predictions_h = gpu_out_predictions.predictions.HostVector(); std::vector<float>& gpu_out_predictions_h = gpu_out_predictions.predictions.HostVector();

View File

@ -64,10 +64,10 @@ void TestTrainingPrediction(size_t rows, size_t bins,
} }
HostDeviceVector<float> from_full; HostDeviceVector<float> from_full;
learner->Predict(p_full, false, &from_full); learner->Predict(p_full, false, &from_full, 0, 0);
HostDeviceVector<float> from_hist; HostDeviceVector<float> from_hist;
learner->Predict(p_hist, false, &from_hist); learner->Predict(p_hist, false, &from_hist, 0, 0);
for (size_t i = 0; i < rows; ++i) { for (size_t i = 0; i < rows; ++i) {
EXPECT_NEAR(from_hist.ConstHostVector()[i], EXPECT_NEAR(from_hist.ConstHostVector()[i],
@ -157,20 +157,20 @@ void TestPredictionWithLesserFeatures(std::string predictor_name) {
learner->SaveConfig(&config); learner->SaveConfig(&config);
ASSERT_EQ(get<String>(config["learner"]["gradient_booster"]["gbtree_train_param"]["predictor"]), predictor_name); ASSERT_EQ(get<String>(config["learner"]["gradient_booster"]["gbtree_train_param"]["predictor"]), predictor_name);
learner->Predict(m_test, false, &prediction); learner->Predict(m_test, false, &prediction, 0, 0);
ASSERT_EQ(prediction.Size(), kRows); ASSERT_EQ(prediction.Size(), kRows);
auto m_invalid = RandomDataGenerator(kRows, kTrainCols + 1, 0.5).GenerateDMatrix(false); auto m_invalid = RandomDataGenerator(kRows, kTrainCols + 1, 0.5).GenerateDMatrix(false);
ASSERT_THROW({learner->Predict(m_invalid, false, &prediction);}, dmlc::Error); ASSERT_THROW({learner->Predict(m_invalid, false, &prediction, 0, 0);}, dmlc::Error);
#if defined(XGBOOST_USE_CUDA) #if defined(XGBOOST_USE_CUDA)
HostDeviceVector<float> from_cpu; HostDeviceVector<float> from_cpu;
learner->SetParam("predictor", "cpu_predictor"); learner->SetParam("predictor", "cpu_predictor");
learner->Predict(m_test, false, &from_cpu); learner->Predict(m_test, false, &from_cpu, 0, 0);
HostDeviceVector<float> from_cuda; HostDeviceVector<float> from_cuda;
learner->SetParam("predictor", "gpu_predictor"); learner->SetParam("predictor", "gpu_predictor");
learner->Predict(m_test, false, &from_cuda); learner->Predict(m_test, false, &from_cuda, 0, 0);
auto const& h_cpu = from_cpu.ConstHostVector(); auto const& h_cpu = from_cpu.ConstHostVector();
auto const& h_gpu = from_cuda.ConstHostVector(); auto const& h_gpu = from_cuda.ConstHostVector();

View File

@ -221,9 +221,10 @@ TEST(Learner, MultiThreadedPredict) {
auto &entry = learner->GetThreadLocal().prediction_entry; auto &entry = learner->GetThreadLocal().prediction_entry;
HostDeviceVector<float> predictions; HostDeviceVector<float> predictions;
for (size_t iter = 0; iter < kIters; ++iter) { for (size_t iter = 0; iter < kIters; ++iter) {
learner->Predict(p_data, false, &entry.predictions); learner->Predict(p_data, false, &entry.predictions, 0, 0);
learner->Predict(p_data, false, &predictions, 0, true); // leaf
learner->Predict(p_data, false, &predictions, 0, false, true); // contribs learner->Predict(p_data, false, &predictions, 0, 0, false, true); // leaf
learner->Predict(p_data, false, &predictions, 0, 0, false, false, true); // contribs
} }
}); });
} }

View File

@ -112,17 +112,24 @@ def _test_cupy_metainfo(DMatrixT):
@pytest.mark.skipif(**tm.no_sklearn()) @pytest.mark.skipif(**tm.no_sklearn())
def test_cupy_training_with_sklearn(): def test_cupy_training_with_sklearn():
import cupy as cp import cupy as cp
np.random.seed(1) np.random.seed(1)
cp.random.seed(1) cp.random.seed(1)
X = cp.random.randn(50, 10, dtype='float32') X = cp.random.randn(50, 10, dtype="float32")
y = (cp.random.randn(50, dtype='float32') > 0).astype('int8') y = (cp.random.randn(50, dtype="float32") > 0).astype("int8")
weights = np.random.random(50) + 1 weights = np.random.random(50) + 1
cupy_weights = cp.array(weights) cupy_weights = cp.array(weights)
base_margin = np.random.random(50) base_margin = np.random.random(50)
cupy_base_margin = cp.array(base_margin) cupy_base_margin = cp.array(base_margin)
clf = xgb.XGBClassifier(gpu_id=0, tree_method='gpu_hist', use_label_encoder=False) clf = xgb.XGBClassifier(gpu_id=0, tree_method="gpu_hist", use_label_encoder=False)
clf.fit(X, y, sample_weight=cupy_weights, base_margin=cupy_base_margin, eval_set=[(X, y)]) clf.fit(
X,
y,
sample_weight=cupy_weights,
base_margin=cupy_base_margin,
eval_set=[(X, y)],
)
pred = clf.predict(X) pred = clf.predict(X)
assert np.array_equal(np.unique(pred), np.array([0, 1])) assert np.array_equal(np.unique(pred), np.array([0, 1]))

View File

@ -17,6 +17,8 @@ if sys.platform.startswith("win"):
sys.path.append("tests/python") sys.path.append("tests/python")
from test_with_dask import run_empty_dmatrix_reg # noqa from test_with_dask import run_empty_dmatrix_reg # noqa
from test_with_dask import run_boost_from_prediction # noqa
from test_with_dask import run_dask_classifier # noqa
from test_with_dask import run_empty_dmatrix_cls # noqa from test_with_dask import run_empty_dmatrix_cls # noqa
from test_with_dask import _get_client_workers # noqa from test_with_dask import _get_client_workers # noqa
from test_with_dask import generate_array # noqa from test_with_dask import generate_array # noqa
@ -132,9 +134,9 @@ def run_gpu_hist(
num_rounds: int, num_rounds: int,
dataset: tm.TestDataset, dataset: tm.TestDataset,
DMatrixT: Type, DMatrixT: Type,
client: Client client: Client,
) -> None: ) -> None:
params['tree_method'] = 'gpu_hist' params["tree_method"] = "gpu_hist"
params = dataset.set_params(params) params = dataset.set_params(params)
# It doesn't make sense to distribute a completely # It doesn't make sense to distribute a completely
# empty dataset. # empty dataset.
@ -143,26 +145,40 @@ def run_gpu_hist(
chunk = 128 chunk = 128
X = to_cp(dataset.X, DMatrixT) X = to_cp(dataset.X, DMatrixT)
X = da.from_array(X, X = da.from_array(X, chunks=(chunk, dataset.X.shape[1]))
chunks=(chunk, dataset.X.shape[1]))
y = to_cp(dataset.y, DMatrixT) y = to_cp(dataset.y, DMatrixT)
y = da.from_array(y, chunks=(chunk, )) y = da.from_array(y, chunks=(chunk,))
if dataset.w is not None: if dataset.w is not None:
w = to_cp(dataset.w, DMatrixT) w = to_cp(dataset.w, DMatrixT)
w = da.from_array(w, chunks=(chunk, )) w = da.from_array(w, chunks=(chunk,))
else: else:
w = None w = None
if DMatrixT is dxgb.DaskDeviceQuantileDMatrix: if DMatrixT is dxgb.DaskDeviceQuantileDMatrix:
m = DMatrixT(client, data=X, label=y, weight=w, m = DMatrixT(
max_bin=params.get('max_bin', 256)) client, data=X, label=y, weight=w, max_bin=params.get("max_bin", 256)
)
else: else:
m = DMatrixT(client, data=X, label=y, weight=w) m = DMatrixT(client, data=X, label=y, weight=w)
history = dxgb.train(client, params=params, dtrain=m, history = dxgb.train(
client,
params=params,
dtrain=m,
num_boost_round=num_rounds, num_boost_round=num_rounds,
evals=[(m, 'train')])['history'] evals=[(m, "train")],
)["history"]
note(history) note(history)
assert tm.non_increasing(history['train'][dataset.metric]) assert tm.non_increasing(history["train"][dataset.metric])
def test_boost_from_prediction(local_cuda_cluster: LocalCUDACluster) -> None:
import cudf
from sklearn.datasets import load_breast_cancer
with Client(local_cuda_cluster) as client:
X_, y_ = load_breast_cancer(return_X_y=True)
X = dd.from_array(X_, chunksize=100).map_partitions(cudf.from_pandas)
y = dd.from_array(y_, chunksize=100).map_partitions(cudf.from_pandas)
run_boost_from_prediction(X, y, "gpu_hist", client)
class TestDistributedGPU: class TestDistributedGPU:
@ -246,6 +262,20 @@ class TestDistributedGPU:
dump = booster.get_dump(dump_format='json') dump = booster.get_dump(dump_format='json')
assert len(dump) - booster.best_iteration == early_stopping_rounds + 1 assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
@pytest.mark.skipif(**tm.no_cudf())
@pytest.mark.skipif(**tm.no_dask())
@pytest.mark.skipif(**tm.no_dask_cuda())
@pytest.mark.parametrize("model", ["boosting"])
def test_dask_classifier(self, model, local_cuda_cluster: LocalCUDACluster) -> None:
import dask_cudf
with Client(local_cuda_cluster) as client:
X_, y_, w_ = generate_array(with_weights=True)
y_ = (y_ * 10).astype(np.int32)
X = dask_cudf.from_dask_dataframe(dd.from_dask_array(X_))
y = dask_cudf.from_dask_dataframe(dd.from_dask_array(y_))
w = dask_cudf.from_dask_dataframe(dd.from_dask_array(w_))
run_dask_classifier(X, y, w, model, client)
@pytest.mark.skipif(**tm.no_dask()) @pytest.mark.skipif(**tm.no_dask())
@pytest.mark.skipif(**tm.no_dask_cuda()) @pytest.mark.skipif(**tm.no_dask_cuda())
@pytest.mark.mgpu @pytest.mark.mgpu

View File

@ -434,7 +434,13 @@ class TestModels:
booster[...:end] = booster booster[...:end] = booster
sliced_0 = booster[1:3] sliced_0 = booster[1:3]
np.testing.assert_allclose(
booster.predict(dtrain, iteration_range=(1, 3)), sliced_0.predict(dtrain)
)
sliced_1 = booster[3:7] sliced_1 = booster[3:7]
np.testing.assert_allclose(
booster.predict(dtrain, iteration_range=(3, 7)), sliced_1.predict(dtrain)
)
predt_0 = sliced_0.predict(dtrain, output_margin=True) predt_0 = sliced_0.predict(dtrain, output_margin=True)
predt_1 = sliced_1.predict(dtrain, output_margin=True) predt_1 = sliced_1.predict(dtrain, output_margin=True)

View File

@ -47,30 +47,27 @@ def run_predict_leaf(predictor):
empty_leaf = booster.predict(empty, pred_leaf=True) empty_leaf = booster.predict(empty, pred_leaf=True)
assert empty_leaf.shape[0] == 0 assert empty_leaf.shape[0] == 0
leaf = booster.predict(m, pred_leaf=True) leaf = booster.predict(m, pred_leaf=True, strict_shape=True)
assert leaf.shape[0] == rows assert leaf.shape[0] == rows
assert leaf.shape[1] == classes * num_parallel_tree * num_boost_round assert leaf.shape[1] == num_boost_round
assert leaf.shape[2] == classes
assert leaf.shape[3] == num_parallel_tree
for i in range(rows): for i in range(rows):
row = leaf[i, ...]
for j in range(num_boost_round): for j in range(num_boost_round):
start = classes * num_parallel_tree * j for k in range(classes):
end = classes * num_parallel_tree * (j + 1) tree_group = leaf[i, j, k, :]
layer = row[start: end]
for c in range(classes):
tree_group = layer[c * num_parallel_tree: (c + 1) * num_parallel_tree]
assert tree_group.shape[0] == num_parallel_tree assert tree_group.shape[0] == num_parallel_tree
# no subsampling so tree in same forest should output same # No sampling, all trees within forest are the same
# leaf.
assert np.all(tree_group == tree_group[0]) assert np.all(tree_group == tree_group[0])
ntree_limit = 2 ntree_limit = 2
sliced = booster.predict( sliced = booster.predict(
m, pred_leaf=True, ntree_limit=num_parallel_tree * ntree_limit m, pred_leaf=True, ntree_limit=num_parallel_tree * ntree_limit, strict_shape=True
) )
first = sliced[0, ...] first = sliced[0, ...]
assert first.shape[0] == classes * num_parallel_tree * ntree_limit assert np.prod(first.shape) == classes * num_parallel_tree * ntree_limit
return leaf return leaf
@ -78,6 +75,23 @@ def test_predict_leaf():
run_predict_leaf('cpu_predictor') run_predict_leaf('cpu_predictor')
def test_predict_shape():
from sklearn.datasets import load_boston
X, y = load_boston(return_X_y=True)
reg = xgb.XGBRegressor(n_estimators=1)
reg.fit(X, y)
predt = reg.get_booster().predict(xgb.DMatrix(X), strict_shape=True)
assert len(predt.shape) == 2
assert predt.shape[0] == X.shape[0]
assert predt.shape[1] == 1
contrib = reg.get_booster().predict(
xgb.DMatrix(X), pred_contribs=True, strict_shape=True
)
assert len(contrib.shape) == 3
assert contrib.shape[1] == 1
class TestInplacePredict: class TestInplacePredict:
'''Tests for running inplace prediction''' '''Tests for running inplace prediction'''
@classmethod @classmethod
@ -92,8 +106,7 @@ class TestInplacePredict:
dtrain = xgb.DMatrix(cls.X, cls.y) dtrain = xgb.DMatrix(cls.X, cls.y)
cls.booster = xgb.train({'tree_method': 'hist'}, cls.booster = xgb.train({'tree_method': 'hist'}, dtrain, num_boost_round=10)
dtrain, num_boost_round=10)
cls.test = xgb.DMatrix(cls.X[:10, ...]) cls.test = xgb.DMatrix(cls.X[:10, ...])

View File

@ -159,12 +159,9 @@ def test_dask_predict_shape_infer(client: "Client") -> None:
assert prediction.shape[1] == 3 assert prediction.shape[1] == 3
@pytest.mark.parametrize("tree_method", ["hist", "approx"]) def run_boost_from_prediction(
def test_boost_from_prediction(tree_method: str, client: "Client") -> None: X: xgb.dask._DaskCollection, y: xgb.dask._DaskCollection, tree_method: str, client: "Client"
from sklearn.datasets import load_breast_cancer ) -> None:
X_, y_ = load_breast_cancer(return_X_y=True)
X, y = dd.from_array(X_, chunksize=100), dd.from_array(y_, chunksize=100)
model_0 = xgb.dask.DaskXGBClassifier( model_0 = xgb.dask.DaskXGBClassifier(
learning_rate=0.3, random_state=0, n_estimators=4, learning_rate=0.3, random_state=0, n_estimators=4,
tree_method=tree_method) tree_method=tree_method)
@ -202,6 +199,30 @@ def test_boost_from_prediction(tree_method: str, client: "Client") -> None:
assert margined_res[i] < unmargined_res[i] assert margined_res[i] < unmargined_res[i]
@pytest.mark.parametrize("tree_method", ["hist", "approx"])
def test_boost_from_prediction(tree_method: str, client: "Client") -> None:
from sklearn.datasets import load_breast_cancer
X_, y_ = load_breast_cancer(return_X_y=True)
X, y = dd.from_array(X_, chunksize=100), dd.from_array(y_, chunksize=100)
run_boost_from_prediction(X, y, tree_method, client)
def test_inplace_predict(client: "Client") -> None:
from sklearn.datasets import load_boston
X_, y_ = load_boston(return_X_y=True)
X, y = dd.from_array(X_, chunksize=32), dd.from_array(y_, chunksize=32)
reg = xgb.dask.DaskXGBRegressor(n_estimators=4).fit(X, y)
booster = reg.get_booster()
base_margin = y
inplace = xgb.dask.inplace_predict(
client, booster, X, base_margin=base_margin
).compute()
Xy = xgb.dask.DaskDMatrix(client, X, base_margin=base_margin)
copied = xgb.dask.predict(client, booster, Xy).compute()
np.testing.assert_allclose(inplace, copied)
def test_dask_missing_value_reg(client: "Client") -> None: def test_dask_missing_value_reg(client: "Client") -> None:
X_0 = np.ones((20 // 2, kCols)) X_0 = np.ones((20 // 2, kCols))
X_1 = np.zeros((20 // 2, kCols)) X_1 = np.zeros((20 // 2, kCols))
@ -288,10 +309,13 @@ def test_dask_regressor(model: str, client: "Client") -> None:
assert forest == 2 assert forest == 2
@pytest.mark.parametrize("model", ["boosting", "rf"]) def run_dask_classifier(
def test_dask_classifier(model: str, client: "Client") -> None: X: xgb.dask._DaskCollection,
X, y, w = generate_array(with_weights=True) y: xgb.dask._DaskCollection,
y = (y * 10).astype(np.int32) w: xgb.dask._DaskCollection,
model: str,
client: "Client",
) -> None:
if model == "boosting": if model == "boosting":
classifier = xgb.dask.DaskXGBClassifier( classifier = xgb.dask.DaskXGBClassifier(
verbosity=1, n_estimators=2, eval_metric="merror" verbosity=1, n_estimators=2, eval_metric="merror"
@ -306,14 +330,13 @@ def test_dask_classifier(model: str, client: "Client") -> None:
classifier.client = client classifier.client = client
classifier.fit(X, y, sample_weight=w, eval_set=[(X, y)]) classifier.fit(X, y, sample_weight=w, eval_set=[(X, y)])
prediction = classifier.predict(X) prediction = classifier.predict(X).compute()
assert prediction.ndim == 1 assert prediction.ndim == 1
assert prediction.shape[0] == kRows assert prediction.shape[0] == kRows
history = classifier.evals_result() history = classifier.evals_result()
assert isinstance(prediction, da.Array)
assert isinstance(history, dict) assert isinstance(history, dict)
assert list(history.keys())[0] == "validation_0" assert list(history.keys())[0] == "validation_0"
@ -332,7 +355,7 @@ def test_dask_classifier(model: str, client: "Client") -> None:
assert forest == 2 assert forest == 2
# Test .predict_proba() # Test .predict_proba()
probas = classifier.predict_proba(X) probas = classifier.predict_proba(X).compute()
assert classifier.n_classes_ == 10 assert classifier.n_classes_ == 10
assert probas.ndim == 2 assert probas.ndim == 2
assert probas.shape[0] == kRows assert probas.shape[0] == kRows
@ -341,18 +364,33 @@ def test_dask_classifier(model: str, client: "Client") -> None:
cls_booster = classifier.get_booster() cls_booster = classifier.get_booster()
single_node_proba = cls_booster.inplace_predict(X.compute()) single_node_proba = cls_booster.inplace_predict(X.compute())
np.testing.assert_allclose(single_node_proba, probas.compute()) # test shared by CPU and GPU
if isinstance(single_node_proba, np.ndarray):
np.testing.assert_allclose(single_node_proba, probas)
else:
import cupy
cupy.testing.assert_allclose(single_node_proba, probas)
# Test with dataframe. # Test with dataframe, not shared with GPU as cupy doesn't work well with da.unique.
X_d = dd.from_dask_array(X) if isinstance(X, da.Array):
y_d = dd.from_dask_array(y) X_d: dd.DataFrame = X.to_dask_dataframe()
classifier.fit(X_d, y_d)
assert classifier.n_classes_ == 10 assert classifier.n_classes_ == 10
prediction = classifier.predict(X_d).compute() prediction_df = classifier.predict(X_d).compute()
assert prediction.ndim == 1 assert prediction_df.ndim == 1
assert prediction.shape[0] == kRows assert prediction_df.shape[0] == kRows
np.testing.assert_allclose(prediction_df, prediction)
probas = classifier.predict_proba(X).compute()
np.testing.assert_allclose(single_node_proba, probas)
@pytest.mark.parametrize("model", ["boosting", "rf"])
def test_dask_classifier(model: str, client: "Client") -> None:
X, y, w = generate_array(with_weights=True)
y = (y * 10).astype(np.int32)
run_dask_classifier(X, y, w, model, client)
@pytest.mark.skipif(**tm.no_sklearn()) @pytest.mark.skipif(**tm.no_sklearn())
@ -913,9 +951,9 @@ class TestWithDask:
train = xgb.dask.DaskDMatrix(client, dX, dy) train = xgb.dask.DaskDMatrix(client, dX, dy)
dX = dd.from_array(X) dX = dd.from_array(X)
dX = client.persist(dX, workers={dX: workers[1]}) dX = client.persist(dX, workers=workers[1])
dy = dd.from_array(y) dy = dd.from_array(y)
dy = client.persist(dy, workers={dy: workers[1]}) dy = client.persist(dy, workers=workers[1])
valid = xgb.dask.DaskDMatrix(client, dX, dy) valid = xgb.dask.DaskDMatrix(client, dX, dy)
merged = xgb.dask._get_workers_from_data(train, evals=[(valid, 'Valid')]) merged = xgb.dask._get_workers_from_data(train, evals=[(valid, 'Valid')])
@ -1060,6 +1098,16 @@ class TestWithDask:
assert_shape(shap.shape) assert_shape(shap.shape)
assert np.allclose(np.sum(shap, axis=len(shap.shape) - 1), margin, 1e-5, 1e-5) assert np.allclose(np.sum(shap, axis=len(shap.shape) - 1), margin, 1e-5, 1e-5)
X = dd.from_dask_array(X).repartition(npartitions=32)
y = dd.from_dask_array(y).repartition(npartitions=32)
shap_df = xgb.dask.predict(
client, booster, X, pred_contribs=True, validate_features=False
).compute()
assert_shape(shap_df.shape)
assert np.allclose(
np.sum(shap_df, axis=len(shap_df.shape) - 1), margin, 1e-5, 1e-5
)
def run_shap_cls_sklearn(self, X: Any, y: Any, client: "Client") -> None: def run_shap_cls_sklearn(self, X: Any, y: Any, client: "Client") -> None:
X, y = da.from_array(X, chunks=(32, -1)), da.from_array(y, chunks=32) X, y = da.from_array(X, chunks=(32, -1)), da.from_array(y, chunks=32)
cls = xgb.dask.DaskXGBClassifier(n_estimators=4) cls = xgb.dask.DaskXGBClassifier(n_estimators=4)