[breaking] Add prediction function for DMatrix and use inplace predict for dask. (#6668)

* Add a new API function for predicting on `DMatrix`. This function aligns with the rest of the `XGBoosterPredictFrom*` functions in the semantics of its arguments.
* Purge `ntree_limit` from libxgboost; use iteration ranges instead.
* [dask] Use `inplace_predict` by default for dask sklearn models.
* [dask] Run prediction shape inference on a worker instead of the client.

The breaking change is in the Python sklearn `apply` function, which is made consistent with the other prediction functions, where `best_iteration` is used by default.
Parent: dbb5208a0a
Commit: 4656b09d5d
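In short, callers migrate from tree-count limits to half-open iteration ranges. A minimal sketch of the before/after, assuming a plain `gbtree` model (data and parameters are illustrative):

```python
import numpy as np
import xgboost as xgb

X, y = np.random.randn(100, 5), np.random.randint(0, 2, size=100)
dtrain = xgb.DMatrix(X, y)
booster = xgb.train({"objective": "binary:logistic"}, dtrain, num_boost_round=20)

# Before: limit by number of trees (now deprecated, emits a warning).
# preds = booster.predict(dtrain, ntree_limit=10)
# After: limit by boosting iterations, a half-open range [0, 10).
preds = booster.predict(dtrain, iteration_range=(0, 10))
```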
include/xgboost/c_api.h
@@ -704,8 +704,9 @@ XGB_DLL int XGBoosterEvalOneIter(BoosterHandle handle,
                                 const char *evnames[],
                                 bst_ulong len,
                                 const char **out_result);

/*!
 * \brief make prediction based on dmat
 * \brief make prediction based on dmat (deprecated, use `XGBoosterPredictFromDMatrix` instead)
 * \param handle handle
 * \param dmat data matrix
 * \param option_mask bit-mask of options taken in prediction, possible values
@@ -734,6 +735,165 @@ XGB_DLL int XGBoosterPredict(BoosterHandle handle,
                             int training,
                             bst_ulong *out_len,
                             const float **out_result);
/*!
 * \brief Make prediction from DMatrix, replacing `XGBoosterPredict`.
 *
 * \param handle Booster handle
 * \param dmat DMatrix handle
 * \param c_json_config String encoded predict configuration in JSON format.
 *
 *    "type": [0, 5]
 *      0: normal prediction
 *      1: output margin
 *      2: predict contribution
 *      3: predict approximated contribution
 *      4: predict feature interaction
 *      5: predict leaf
 *    "training": bool
 *      Whether the prediction function is used as part of a training loop.  **Not used
 *      for inplace prediction**.
 *
 *      Prediction can be run in 2 scenarios:
 *        1. Given data matrix X, obtain prediction y_pred from the model.
 *        2. Obtain the prediction for computing gradients.  For example, DART booster performs dropout
 *           during training, and the prediction result will be different from the one obtained by a normal
 *           inference step due to dropped trees.
 *      Set training=false for the first scenario.  Set training=true for the second
 *      scenario.  The second scenario applies when you are defining a custom objective
 *      function.
 *    "iteration_begin": int
 *      Beginning iteration of prediction.
 *    "iteration_end": int
 *      End iteration of prediction.  When set to 0, this defaults to the size of the tree model.
 *    "strict_shape": bool
 *      Whether the output should be reshaped with stricter rules.  If set to true,
 *      normal/margin/contrib/interaction predict will output a consistent shape
 *      disregarding the use of a multi-class model, and leaf prediction will output a 4-dim
 *      array representing: (n_samples, n_iterations, n_classes, n_trees_in_forest)
 *
 *   Run a normal prediction with strict output shape, 2 dim for softprob, 1 dim for others.
 * \code
 *   {
 *      "type": 0,
 *      "training": false,
 *      "iteration_begin": 0,
 *      "iteration_end": 0,
 *      "strict_shape": true
 *   }
 * \endcode
 *
 * \param out_shape Shape of output prediction (copy before use).
 * \param out_dim Dimension of output prediction.
 * \param out_result Buffer storing prediction value (copy before use).
 *
 * \return 0 when success, -1 when failure happens
 */
XGB_DLL int XGBoosterPredictFromDMatrix(BoosterHandle handle,
                                        DMatrixHandle dmat,
                                        char const* c_json_config,
                                        bst_ulong const **out_shape,
                                        bst_ulong *out_dim,
                                        float const **out_result);
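As a rough illustration, the new entry point can be driven from Python through `ctypes`, much as the PR's own `core.py` does. A minimal sketch, assuming `lib` is the loaded libxgboost shared library and `booster_handle`/`dmat_handle` are valid handles obtained elsewhere:

```python
import ctypes
import json

config = json.dumps({
    "type": 0,            # normal prediction
    "training": False,
    "iteration_begin": 0,
    "iteration_end": 0,   # 0 means "use all trees"
    "strict_shape": True,
}).encode("utf-8")

shape = ctypes.POINTER(ctypes.c_uint64)()  # bst_ulong is a 64-bit unsigned int
dims = ctypes.c_uint64()
preds = ctypes.POINTER(ctypes.c_float)()
ret = lib.XGBoosterPredictFromDMatrix(
    booster_handle, dmat_handle, config,
    ctypes.byref(shape), ctypes.byref(dims), ctypes.byref(preds))
assert ret == 0  # non-zero signals failure; query XGBGetLastError for details
# `shape` and `preds` point into library-owned memory: copy before the next call.
```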
/*
 * \brief Inplace prediction from CPU dense matrix.
 *
 * \param handle        Booster handle.
 * \param values        JSON encoded __array_interface__ to values.
 * \param c_json_config See `XGBoosterPredictFromDMatrix` for more info.
 *
 *   Additional fields for inplace prediction are:
 *     "missing": float
 *
 * \param m An optional (NULL if not available) proxy DMatrix instance
 *          storing meta info.
 *
 * \param out_shape  See `XGBoosterPredictFromDMatrix` for more info.
 * \param out_dim    See `XGBoosterPredictFromDMatrix` for more info.
 * \param out_result See `XGBoosterPredictFromDMatrix` for more info.
 *
 * \return 0 when success, -1 when failure happens
 */
XGB_DLL int XGBoosterPredictFromDense(BoosterHandle handle,
                                      char const *values,
                                      char const *c_json_config,
                                      DMatrixHandle m,
                                      bst_ulong const **out_shape,
                                      bst_ulong *out_dim,
                                      const float **out_result);

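The `values` argument is the JSON form of NumPy's `__array_interface__`. A sketch of how a caller might produce it; the contiguity check is an assumption about what the library expects of a non-view input:

```python
import json
import numpy as np

def array_interface_json(data: np.ndarray) -> bytes:
    """Encode a dense, C-contiguous array as __array_interface__ JSON."""
    assert data.flags.c_contiguous, "inplace predict expects a contiguous, non-view array"
    interface = data.__array_interface__
    # 'data' is a (pointer, read-only) pair; tuples serialize as JSON lists.
    return json.dumps(interface).encode("utf-8")

X = np.random.randn(8, 4).astype(np.float32)
values = array_interface_json(X)  # passed as `values` to XGBoosterPredictFromDense
```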
/*
 * \brief Inplace prediction from CPU CSR matrix.
 *
 * \param handle        Booster handle.
 * \param indptr        JSON encoded __array_interface__ to row pointer in CSR.
 * \param indices       JSON encoded __array_interface__ to column indices in CSR.
 * \param values        JSON encoded __array_interface__ to values in CSR.
 * \param ncol          Number of features in data.
 * \param c_json_config See `XGBoosterPredictFromDMatrix` for more info.
 *   Additional fields for inplace prediction are:
 *     "missing": float
 *
 * \param m An optional (NULL if not available) proxy DMatrix instance
 *          storing meta info.
 *
 * \param out_shape  See `XGBoosterPredictFromDMatrix` for more info.
 * \param out_dim    See `XGBoosterPredictFromDMatrix` for more info.
 * \param out_result See `XGBoosterPredictFromDMatrix` for more info.
 *
 * \return 0 when success, -1 when failure happens
 */
XGB_DLL int XGBoosterPredictFromCSR(BoosterHandle handle, char const *indptr,
                                    char const *indices, char const *values,
                                    bst_ulong ncol,
                                    char const *c_json_config, DMatrixHandle m,
                                    bst_ulong const **out_shape,
                                    bst_ulong *out_dim,
                                    const float **out_result);

/*
 * \brief Inplace prediction from CUDA Dense matrix (cupy in Python).
 *
 * \param handle        Booster handle
 * \param values        JSON encoded __cuda_array_interface__ to values.
 * \param c_json_config See `XGBoosterPredictFromDMatrix` for more info.
 *   Additional fields for inplace prediction are:
 *     "missing": float
 *
 * \param m An optional (NULL if not available) proxy DMatrix instance
 *          storing meta info.
 * \param out_shape  See `XGBoosterPredictFromDMatrix` for more info.
 * \param out_dim    See `XGBoosterPredictFromDMatrix` for more info.
 * \param out_result See `XGBoosterPredictFromDMatrix` for more info.
 *
 * \return 0 when success, -1 when failure happens
 */
XGB_DLL int XGBoosterPredictFromCudaArray(
    BoosterHandle handle, char const *values, char const *c_json_config,
    DMatrixHandle m, bst_ulong const **out_shape, bst_ulong *out_dim,
    const float **out_result);

/*
 * \brief Inplace prediction from CUDA dense dataframe (cuDF in Python).
 *
 * \param handle        Booster handle
 * \param values        List of __cuda_array_interface__ for all columns encoded in JSON list.
 * \param c_json_config See `XGBoosterPredictFromDMatrix` for more info.
 *   Additional fields for inplace prediction are:
 *     "missing": float
 *
 * \param m An optional (NULL if not available) proxy DMatrix instance
 *          storing meta info.
 * \param out_shape  See `XGBoosterPredictFromDMatrix` for more info.
 * \param out_dim    See `XGBoosterPredictFromDMatrix` for more info.
 * \param out_result See `XGBoosterPredictFromDMatrix` for more info.
 *
 * \return 0 when success, -1 when failure happens
 */
XGB_DLL int XGBoosterPredictFromCudaColumnar(
    BoosterHandle handle, char const *values, char const *c_json_config,
    DMatrixHandle m, bst_ulong const **out_shape, bst_ulong *out_dim,
    const float **out_result);

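For the columnar entry point, `values` is a JSON *list* of per-column `__cuda_array_interface__` dicts. A hedged sketch of the encoding, mirroring what the PR's `_cudf_array_interfaces` helper presumably assembles (null-mask handling omitted):

```python
import json

def cudf_column_interfaces(df) -> bytes:
    """Encode each column of a cuDF DataFrame as a JSON list of
    __cuda_array_interface__ dicts.  A sketch only: real code must also
    serialize the validity mask when a column contains nulls."""
    interfaces = [df[col].__cuda_array_interface__ for col in df.columns]
    return json.dumps(interfaces).encode("utf-8")
```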
/*
 * ========================== Begin Serialization APIs =========================

include/xgboost/gbm.h
@@ -63,7 +63,7 @@ class GradientBooster : public Model, public Configurable {
  /*!
   * \brief Slice a model using boosting index. The slice m:n indicates taking all trees
   *        that were fit during the boosting rounds m, (m+1), (m+2), ..., (n-1).
   * \param layer_begin Begining of boosted tree layer used for prediction.
   * \param layer_begin Beginning of boosted tree layer used for prediction.
   * \param layer_end End of booster layer. 0 means do not limit trees.
   * \param out Output gradient booster
   */
@@ -99,15 +99,14 @@ class GradientBooster : public Model, public Configurable {
   * \param out_preds output vector to hold the predictions
   * \param training Whether the prediction value is used for training. For dart booster
   *                 drop out is performed during training.
   * \param ntree_limit limit the number of trees used in prediction,
   *                    when it equals 0, this means we do not limit
   *                    number of trees, this parameter is only valid
   *                    for gbtree, but not for gblinear
   * \param layer_begin Beginning of boosted tree layer used for prediction.
   * \param layer_end End of booster layer. 0 means do not limit trees.
   */
  virtual void PredictBatch(DMatrix* dmat,
                            PredictionCacheEntry* out_preds,
                            bool training,
                            unsigned ntree_limit = 0) = 0;
                            unsigned layer_begin,
                            unsigned layer_end) = 0;

  /*!
   * \brief Inplace prediction.
@@ -115,7 +114,7 @@ class GradientBooster : public Model, public Configurable {
   * \param x A type erased data adapter.
   * \param missing Missing value in the data.
   * \param [in,out] out_preds The output preds.
   * \param layer_begin (Optional) Begining of boosted tree layer used for prediction.
   * \param layer_begin (Optional) Beginning of boosted tree layer used for prediction.
   * \param layer_end (Optional) End of booster layer. 0 means do not limit trees.
   */
  virtual void InplacePredict(dmlc::any const &, std::shared_ptr<DMatrix>, float,
@@ -132,44 +131,45 @@ class GradientBooster : public Model, public Configurable {
   *
   * \param inst the instance you want to predict
   * \param out_preds output vector to hold the predictions
   * \param ntree_limit limit the number of trees used in prediction
   * \param layer_begin Beginning of boosted tree layer used for prediction.
   * \param layer_end End of booster layer. 0 means do not limit trees.
   * \sa Predict
   */
  virtual void PredictInstance(const SparsePage::Inst& inst,
                               std::vector<bst_float>* out_preds,
                               unsigned ntree_limit = 0) = 0;
                               unsigned layer_begin, unsigned layer_end) = 0;
  /*!
   * \brief predict the leaf index of each tree, the output will be nsample * ntree vector
   *        this is only valid in gbtree predictor
   * \param dmat feature matrix
   * \param out_preds output vector to hold the predictions
   * \param ntree_limit limit the number of trees used in prediction, when it equals 0, this means
   *        we do not limit number of trees, this parameter is only valid for gbtree, but not for gblinear
   * \param layer_begin Beginning of boosted tree layer used for prediction.
   * \param layer_end End of booster layer. 0 means do not limit trees.
   */
  virtual void PredictLeaf(DMatrix* dmat,
                           HostDeviceVector<bst_float>* out_preds,
                           unsigned ntree_limit = 0) = 0;
  virtual void PredictLeaf(DMatrix *dmat,
                           HostDeviceVector<bst_float> *out_preds,
                           unsigned layer_begin, unsigned layer_end) = 0;

  /*!
   * \brief feature contributions to individual predictions; the output will be a vector
   *        of length (nfeats + 1) * num_output_group * nsample, arranged in that order
   * \param dmat feature matrix
   * \param out_contribs output vector to hold the contributions
   * \param ntree_limit limit the number of trees used in prediction, when it equals 0, this means
   *        we do not limit number of trees
   * \param layer_begin Beginning of boosted tree layer used for prediction.
   * \param layer_end End of booster layer. 0 means do not limit trees.
   * \param approximate use a faster (inconsistent) approximation of SHAP values
   * \param condition condition on the condition_feature (0=no, -1=cond off, 1=cond on).
   * \param condition_feature feature to condition on (i.e. fix) during calculations
   */
  virtual void PredictContribution(DMatrix* dmat,
                                   HostDeviceVector<bst_float>* out_contribs,
                                   unsigned ntree_limit = 0,
                                   unsigned layer_begin, unsigned layer_end,
                                   bool approximate = false, int condition = 0,
                                   unsigned condition_feature = 0) = 0;

  virtual void PredictInteractionContributions(DMatrix* dmat,
                                               HostDeviceVector<bst_float>* out_contribs,
                                               unsigned ntree_limit, bool approximate) = 0;
  virtual void PredictInteractionContributions(
      DMatrix *dmat, HostDeviceVector<bst_float> *out_contribs,
      unsigned layer_begin, unsigned layer_end, bool approximate) = 0;

  /*!
   * \brief dump the model in the requested format

include/xgboost/learner.h
@@ -113,8 +113,8 @@ class Learner : public Model, public Configurable, public dmlc::Serializable {
   * \param data input data
   * \param output_margin whether to only predict margin value instead of transformed prediction
   * \param out_preds output vector that stores the prediction
   * \param ntree_limit limit number of trees used for boosted tree
   *        predictor, when it equals 0, this means we are using all the trees
   * \param layer_begin Beginning of boosted tree layer used for prediction.
   * \param layer_end End of booster layer. 0 means do not limit trees.
   * \param training Whether the prediction result is used for training
   * \param pred_leaf whether to only predict the leaf index of each tree in a boosted tree predictor
   * \param pred_contribs whether to only predict the feature contributions
@@ -124,7 +124,8 @@ class Learner : public Model, public Configurable, public dmlc::Serializable {
  virtual void Predict(std::shared_ptr<DMatrix> data,
                       bool output_margin,
                       HostDeviceVector<bst_float> *out_preds,
                       unsigned ntree_limit = 0,
                       unsigned layer_begin,
                       unsigned layer_end,
                       bool training = false,
                       bool pred_leaf = false,
                       bool pred_contribs = false,
@@ -140,7 +141,7 @@ class Learner : public Model, public Configurable, public dmlc::Serializable {
   * \param type Prediction type.
   * \param missing Missing value in the data.
   * \param [in,out] out_preds Pointer to output prediction vector.
   * \param layer_begin Begining of boosted tree layer used for prediction.
   * \param layer_begin Beginning of boosted tree layer used for prediction.
   * \param layer_end End of booster layer. 0 means do not limit trees.
   */
  virtual void InplacePredict(dmlc::any const &x,

include/xgboost/predictor.h
@@ -127,12 +127,11 @@ class Predictor {
   * \param [in,out] out_preds The output preds.
   * \param model The model to predict from.
   * \param tree_begin The tree begin index.
   * \param ntree_limit (Optional) The ntree limit. 0 means do not
   *                    limit trees.
   * \param tree_end The tree end index.
   */
  virtual void PredictBatch(DMatrix* dmat, PredictionCacheEntry* out_preds,
                            const gbm::GBTreeModel& model, int tree_begin,
                            uint32_t const ntree_limit = 0) const = 0;
                            const gbm::GBTreeModel& model, uint32_t tree_begin,
                            uint32_t tree_end = 0) const = 0;

  /**
   * \brief Inplace prediction.
@@ -140,7 +139,7 @@ class Predictor {
   * \param model The model to predict from.
   * \param missing Missing value in the data.
   * \param [in,out] out_preds The output preds.
   * \param tree_begin (Optional) Begining of boosted trees used for prediction.
   * \param tree_begin (Optional) Beginning of boosted trees used for prediction.
   * \param tree_end (Optional) End of booster trees. 0 means do not limit trees.
   *
   * \return True if the data can be handled by current predictor, false otherwise.
@@ -159,13 +158,13 @@ class Predictor {
   * \param inst The instance to predict.
   * \param [in,out] out_preds The output preds.
   * \param model The model to predict from
   * \param ntree_limit (Optional) The ntree limit.
   * \param tree_end (Optional) The tree end index.
   */

  virtual void PredictInstance(const SparsePage::Inst& inst,
                               std::vector<bst_float>* out_preds,
                               const gbm::GBTreeModel& model,
                               unsigned ntree_limit = 0) const = 0;
                               unsigned tree_end = 0) const = 0;

  /**
   * \brief predict the leaf index of each tree, the output will be nsample *
@@ -174,18 +173,14 @@ class Predictor {
   * \param [in,out] dmat The input feature matrix.
   * \param [in,out] out_preds The output preds.
   * \param model Model to make predictions from.
   * \param ntree_limit (Optional) The ntree limit.
   * \param tree_end (Optional) The tree end index.
   */

  virtual void PredictLeaf(DMatrix* dmat, HostDeviceVector<bst_float>* out_preds,
                           const gbm::GBTreeModel& model,
                           unsigned ntree_limit = 0) const = 0;
                           unsigned tree_end = 0) const = 0;

  /**
   * \fn virtual void Predictor::PredictContribution( DMatrix* dmat,
   *     std::vector<bst_float>* out_contribs, const gbm::GBTreeModel& model,
   *     unsigned ntree_limit = 0) = 0;
   *
   * \brief feature contributions to individual predictions; the output will be
   * a vector of length (nfeats + 1) * num_output_group * nsample, arranged in
   * that order.
@@ -193,7 +188,7 @@ class Predictor {
   * \param [in,out] dmat The input feature matrix.
   * \param [in,out] out_contribs The output feature contribs.
   * \param model Model to make predictions from.
   * \param ntree_limit (Optional) The ntree limit.
   * \param tree_end The tree end index.
   * \param tree_weights (Optional) Weights to multiply each tree by.
   * \param approximate Use fast approximate algorithm.
   * \param condition Condition on the condition_feature (0=no, -1=cond off, 1=cond on).
@@ -203,7 +198,7 @@ class Predictor {
  virtual void PredictContribution(DMatrix* dmat,
                                   HostDeviceVector<bst_float>* out_contribs,
                                   const gbm::GBTreeModel& model,
                                   unsigned ntree_limit = 0,
                                   unsigned tree_end = 0,
                                   std::vector<bst_float>* tree_weights = nullptr,
                                   bool approximate = false,
                                   int condition = 0,
@@ -212,7 +207,7 @@ class Predictor {
  virtual void PredictInteractionContributions(DMatrix* dmat,
                                               HostDeviceVector<bst_float>* out_contribs,
                                               const gbm::GBTreeModel& model,
                                               unsigned ntree_limit = 0,
                                               unsigned tree_end = 0,
                                               std::vector<bst_float>* tree_weights = nullptr,
                                               bool approximate = false) const = 0;


python-package/xgboost/core.py
@@ -96,6 +96,24 @@ def from_cstr_to_pystr(data, length):
    return res


def _convert_ntree_limit(booster, ntree_limit, iteration_range):
    if ntree_limit is not None and ntree_limit != 0:
        warnings.warn(
            "ntree_limit is deprecated, use `iteration_range` or model "
            "slicing instead.",
            UserWarning
        )
        if iteration_range is not None and iteration_range[1] != 0:
            raise ValueError(
                "Only one of `iteration_range` and `ntree_limit` can be non zero."
            )
        num_parallel_tree, num_groups = _get_booster_layer_trees(booster)
        num_parallel_tree = max([num_parallel_tree, 1])
        num_groups = max([num_groups, 1])
        iteration_range = (0, ntree_limit // num_parallel_tree)
    return iteration_range
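The tail of this helper is just integer arithmetic; a small worked example of the mapping for a random-forest-style booster (numbers are illustrative):

```python
# Hypothetical illustration of the ntree_limit -> iteration_range mapping.
num_parallel_tree = 4          # trees added per boosting round (random forest)
ntree_limit = 12               # old-style limit, counted in individual trees
iteration_range = (0, ntree_limit // num_parallel_tree)
assert iteration_range == (0, 3)   # rounds [0, 3) cover the first 12 trees
```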
def _expect(expectations, got):
    """Translate input error into string.

@@ -1111,6 +1129,34 @@ Objective = Callable[[np.ndarray, DMatrix], Tuple[np.ndarray, np.ndarray]]
Metric = Callable[[np.ndarray, DMatrix], Tuple[str, float]]


def _get_booster_layer_trees(model: "Booster") -> Tuple[int, int]:
    """Get number of trees added to booster per-iteration.  This function will be removed
    once `best_ntree_limit` is dropped in favor of `best_iteration`.  Returns
    `num_parallel_tree` and `num_groups`.

    """
    config = json.loads(model.save_config())
    booster = config["learner"]["gradient_booster"]["name"]
    if booster == "gblinear":
        num_parallel_tree = 0
    elif booster == "dart":
        num_parallel_tree = int(
            config["learner"]["gradient_booster"]["gbtree"]["gbtree_train_param"][
                "num_parallel_tree"
            ]
        )
    elif booster == "gbtree":
        num_parallel_tree = int(
            config["learner"]["gradient_booster"]["gbtree_train_param"][
                "num_parallel_tree"
            ]
        )
    else:
        raise ValueError(f"Unknown booster: {booster}")
    num_groups = int(config["learner"]["learner_model_param"]["num_class"])
    return num_parallel_tree, num_groups
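The helper leans on `Booster.save_config()`, which returns the full learner configuration as JSON. A quick way to see the fields it reads (dataset and parameters are arbitrary; note that values in the saved config are strings):

```python
import json
import numpy as np
import xgboost as xgb

X = np.random.randn(64, 4)
y = np.random.randint(0, 3, size=64)
booster = xgb.train(
    {"objective": "multi:softprob", "num_class": 3, "num_parallel_tree": 2},
    xgb.DMatrix(X, label=y),
    num_boost_round=2,
)
config = json.loads(booster.save_config())
gbm = config["learner"]["gradient_booster"]
print(gbm["name"])                                      # "gbtree"
print(gbm["gbtree_train_param"]["num_parallel_tree"])   # "2" (a string)
print(config["learner"]["learner_model_param"]["num_class"])  # "3"
```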
class Booster(object):
    # pylint: disable=too-many-public-methods
    """A Booster of XGBoost.
@@ -1497,16 +1543,20 @@ class Booster(object):
        return self.eval_set([(data, name)], iteration)

    # pylint: disable=too-many-function-args
    def predict(self,
                data,
                output_margin=False,
                ntree_limit=0,
                pred_leaf=False,
                pred_contribs=False,
                approx_contribs=False,
                pred_interactions=False,
                validate_features=True,
                training=False):
    def predict(
        self,
        data: DMatrix,
        output_margin: bool = False,
        ntree_limit: int = 0,
        pred_leaf: bool = False,
        pred_contribs: bool = False,
        approx_contribs: bool = False,
        pred_interactions: bool = False,
        validate_features: bool = True,
        training: bool = False,
        iteration_range: Tuple[int, int] = (0, 0),
        strict_shape: bool = False,
    ) -> np.ndarray:
        """Predict with data.

        .. note:: This function is not thread safe except for ``gbtree`` booster.
@@ -1518,33 +1568,32 @@ class Booster(object):

        Parameters
        ----------
        data : DMatrix
        data :
            The dmatrix storing the input.

        output_margin : bool
        output_margin :
            Whether to output the raw untransformed margin value.

        ntree_limit : int
            Limit number of trees in the prediction; defaults to 0 (use all
            trees).
        ntree_limit :
            Deprecated, use `iteration_range` instead.

        pred_leaf : bool
        pred_leaf :
            When this option is on, the output will be a matrix of (nsample,
            ntrees) with each record indicating the predicted leaf index of
            each sample in each tree.  Note that the leaf index of a tree is
            unique per tree, so you may find leaf 1 in both tree 1 and tree 0.

        pred_contribs : bool
        pred_contribs :
            When this is True the output will be a matrix of size (nsample,
            nfeats + 1) with each record indicating the feature contributions
            (SHAP values) for that prediction.  The sum of all feature
            contributions is equal to the raw untransformed margin value of the
            prediction.  Note the final column is the bias term.

        approx_contribs : bool
        approx_contribs :
            Approximate the contributions of each feature

        pred_interactions : bool
        pred_interactions :
            When this is True the output will be a matrix of size (nsample,
            nfeats + 1, nfeats + 1) indicating the SHAP interaction values for
            each pair of features.  The sum of each row (or column) of the
@@ -1553,17 +1602,33 @@ class Booster(object):
            untransformed margin value of the prediction.  Note the last row and
            column correspond to the bias term.

        validate_features : bool
        validate_features :
            When this is True, validate that the Booster's and data's
            feature_names are identical.  Otherwise, it is assumed that the
            feature_names are the same.

        training : bool
        training :
            Whether the prediction value is used for training.  This can affect
            `dart` booster, which performs dropouts during training iterations.

            .. versionadded:: 1.0.0

        iteration_range :
            Specifies which layer of trees are used in prediction.  For example, if a
            random forest is trained with 100 rounds, specifying `iteration_range=(10,
            20)` means only the forests built during rounds [10, 20) (half open set) are
            used in this prediction.

            .. versionadded:: 1.4.0

        strict_shape :
            When set to True, output shape is invariant to whether classification is used.
            For both value and margin prediction, the output shape is (n_samples,
            n_groups), n_groups == 1 when multi-class is not used.  Defaults to False, in
            which case the output shape can be (n_samples, ) if multi-class is not used.

            .. versionadded:: 1.4.0

        .. note:: Using ``predict()`` with DART booster

            If the booster object is DART type, ``predict()`` will not perform
@@ -1575,64 +1640,50 @@ class Booster(object):
        prediction : numpy array

        """
        option_mask = 0x00
        if output_margin:
            option_mask |= 0x01
        if pred_leaf:
            option_mask |= 0x02
        if pred_contribs:
            option_mask |= 0x04
        if approx_contribs:
            option_mask |= 0x08
        if pred_interactions:
            option_mask |= 0x10

        if not isinstance(data, DMatrix):
            raise TypeError('Expecting data to be a DMatrix object, got: ',
                            type(data))

            raise TypeError('Expecting data to be a DMatrix object, got: ', type(data))
        if validate_features:
            self._validate_features(data)
        iteration_range = _convert_ntree_limit(self, ntree_limit, iteration_range)
        args = {
            "type": 0,
            "training": training,
            "iteration_begin": iteration_range[0],
            "iteration_end": iteration_range[1],
            "strict_shape": strict_shape,
        }

        length = c_bst_ulong()
        preds = ctypes.POINTER(ctypes.c_float)()
        _check_call(_LIB.XGBoosterPredict(self.handle, data.handle,
                                          ctypes.c_int(option_mask),
                                          ctypes.c_uint(ntree_limit),
                                          ctypes.c_int(training),
                                          ctypes.byref(length),
                                          ctypes.byref(preds)))
        preds = ctypes2numpy(preds, length.value, np.float32)
        def assign_type(t: int) -> None:
            if args["type"] != 0:
                raise ValueError("One type of prediction at a time.")
            args["type"] = t

        if output_margin:
            assign_type(1)
        if pred_contribs:
            assign_type(2 if not approx_contribs else 3)
        if pred_interactions:
            assign_type(4)
        if pred_leaf:
            preds = preds.astype(np.int32, copy=False)
            nrow = data.num_row()
            if preds.size != nrow and preds.size % nrow == 0:
                chunk_size = int(preds.size / nrow)

                if pred_interactions:
                    ngroup = int(chunk_size / ((data.num_col() + 1) *
                                               (data.num_col() + 1)))
                    if ngroup == 1:
                        preds = preds.reshape(nrow,
                                              data.num_col() + 1,
                                              data.num_col() + 1)
                    else:
                        preds = preds.reshape(nrow, ngroup,
                                              data.num_col() + 1,
                                              data.num_col() + 1)
                elif pred_contribs:
                    ngroup = int(chunk_size / (data.num_col() + 1))
                    if ngroup == 1:
                        preds = preds.reshape(nrow, data.num_col() + 1)
                    else:
                        preds = preds.reshape(nrow, ngroup, data.num_col() + 1)
                else:
                    preds = preds.reshape(nrow, chunk_size)
        return preds
            assign_type(5)
        preds = ctypes.POINTER(ctypes.c_float)()
        shape = ctypes.POINTER(c_bst_ulong)()
        dims = c_bst_ulong()
        _check_call(
            _LIB.XGBoosterPredictFromDMatrix(
                self.handle,
                data.handle,
                from_pystr_to_cstr(json.dumps(args)),
                ctypes.byref(shape),
                ctypes.byref(dims),
                ctypes.byref(preds)
            )
        )
        return _prediction_output(shape, dims, preds, False)
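A short sketch of how the two new keyword arguments change the call site (model and data are placeholders):

```python
import numpy as np
import xgboost as xgb

X, y = np.random.randn(64, 4), np.random.randint(0, 2, size=64)
dtrain = xgb.DMatrix(X, y)
booster = xgb.train({"objective": "binary:logistic"}, dtrain, num_boost_round=10)

# Half-open range of boosting iterations replaces ntree_limit.
predt = booster.predict(dtrain, iteration_range=(0, 8))
assert predt.shape == (64,)

# strict_shape always yields an explicit group axis, even for binary models.
predt = booster.predict(dtrain, iteration_range=(0, 8), strict_shape=True)
assert predt.shape == (64, 1)
```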
    def inplace_predict(
        self,
        data,
        data: Any,
        iteration_range: Tuple[int, int] = (0, 0),
        predict_type: str = "value",
        missing: float = np.nan,
@@ -1665,26 +1716,24 @@ class Booster(object):
            The input data, must not be a view for numpy array.  Set
            ``predictor`` to ``gpu_predictor`` for running prediction on CuPy
            array or CuDF DataFrame.
        iteration_range : tuple
            Specifies which layer of trees are used in prediction.  For
            example, if a random forest is trained with 100 rounds.  Specifying
            `iteration_range=(10, 20)`, then only the forests built during [10,
            20) (open set) rounds are used in this prediction.
        predict_type : str
        iteration_range :
            See :py:meth:`xgboost.Booster.predict` for details.
        predict_type :
            * `value` Output model prediction values.
            * `margin` Output the raw untransformed margin value.
        missing : float
            Value in the input data which needs to be present as a missing
            value.
        missing :
            See :py:obj:`xgboost.DMatrix` for details.
        validate_features:
            See :py:meth:`xgboost.Booster.predict` for details.
        base_margin:
            See :py:obj:`xgboost.DMatrix` for details.

            .. versionadded:: 1.4.0

        strict_shape:
            When set to True, output shape is invariant to whether classification is used.
            For both value and margin prediction, the output shape is (n_samples,
            n_groups), n_groups == 1 when multi-class is not used.  Defaults to False, in
            which case the output shape can be (n_samples, ) if multi-class is not used.
            See :py:meth:`xgboost.Booster.predict` for details.

            .. versionadded:: 1.4.0

        Returns
        -------
@@ -1772,7 +1821,7 @@ class Booster(object):
                interface["mask"] = interface["mask"].__cuda_array_interface__
            interface_str = bytes(json.dumps(interface, indent=2), "utf-8")
            _check_call(
                _LIB.XGBoosterPredictFromArrayInterface(
                _LIB.XGBoosterPredictFromCudaArray(
                    self.handle,
                    interface_str,
                    from_pystr_to_cstr(json.dumps(args)),
@@ -1788,7 +1837,7 @@ class Booster(object):

            interfaces_str = _cudf_array_interfaces(data)
            _check_call(
                _LIB.XGBoosterPredictFromArrayInterfaceColumns(
                _LIB.XGBoosterPredictFromCudaColumnar(
                    self.handle,
                    interfaces_str,
                    from_pystr_to_cstr(json.dumps(args)),

python-package/xgboost/dask.py
@@ -254,6 +254,7 @@ class DaskDMatrix:
            raise TypeError(_expect((dd.DataFrame, da.Array, dd.Series), type(label)))

        self._n_cols = data.shape[1]
        assert isinstance(self._n_cols, int)
        self.worker_map: Dict[str, "distributed.Future"] = defaultdict(list)
        self.is_quantile: bool = False

@@ -881,7 +882,7 @@ async def _train_async(
    return list(filter(lambda ret: ret is not None, results))[0]


def train(
def train(  # pylint: disable=unused-argument
    client: "distributed.Client",
    params: Dict[str, Any],
    dtrain: DaskDMatrix,
@@ -892,16 +893,17 @@ def train(
    early_stopping_rounds: Optional[int] = None,
    xgb_model: Optional[Booster] = None,
    verbose_eval: Union[int, bool] = True,
    callbacks: Optional[List[TrainingCallback]] = None
    callbacks: Optional[List[TrainingCallback]] = None,
) -> Any:
    '''Train XGBoost model.
    """Train XGBoost model.

    .. versionadded:: 1.0.0

    .. note::

        Other parameters are the same as `xgboost.train` except for `evals_result`, which
        is returned as part of function return value instead of argument.
        Other parameters are the same as :py:func:`xgboost.train` except for
        `evals_result`, which is returned as part of function return value instead of
        argument.

    Parameters
    ----------
@@ -920,29 +922,17 @@ def train(
        {'booster': xgboost.Booster,
         'history': {'train': {'logloss': ['0.48253', '0.35953']},
                     'eval': {'logloss': ['0.480385', '0.357756']}}}
    '''

    """
    _assert_dask_support()
    client = _xgb_get_client(client)
    # Get global configuration before transferring computation to another thread or
    # process.
    global_config = config.get_config()
    return client.sync(_train_async,
                       client=client,
                       global_config=global_config,
                       num_boost_round=num_boost_round,
                       obj=obj,
                       feval=feval,
                       params=params,
                       dtrain=dtrain,
                       evals=evals,
                       early_stopping_rounds=early_stopping_rounds,
                       verbose_eval=verbose_eval,
                       xgb_model=xgb_model,
                       callbacks=callbacks)
    return client.sync(_train_async, global_config=config.get_config(), **locals())
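The `**locals()` simplification forwards every parameter by name, which is why the `# pylint: disable=unused-argument` marker appears on `train`: the arguments look unused to the linter but are consumed through `locals()`. A hedged sketch of the pattern in isolation:

```python
# Hypothetical illustration of forwarding keyword arguments via locals().
def _impl(*, a, b, extra):
    return a + b + extra

def public(a, b):  # pylint: disable=unused-argument
    # locals() here is {"a": a, "b": b}; computed values are added alongside.
    return _impl(extra=10, **locals())

assert public(1, 2) == 13
```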

def _can_output_df(data: _DaskCollection, output_shape: Tuple) -> bool:
    return isinstance(data, dd.DataFrame) and len(output_shape) <= 2
def _can_output_df(is_df: bool, output_shape: Tuple) -> bool:
    return is_df and len(output_shape) <= 2


async def _direct_predict_impl(
@@ -954,8 +944,9 @@ async def _direct_predict_impl(
    meta: Dict[int, str],
) -> _DaskCollection:
    columns = list(meta.keys())
    if _can_output_df(data, output_shape):
    if _can_output_df(isinstance(data, dd.DataFrame), output_shape):
        if base_margin is not None and isinstance(base_margin, da.Array):
            # Easier for map_partitions
            base_margin_df: Optional[dd.DataFrame] = base_margin.to_dask_dataframe()
        else:
            base_margin_df = base_margin
@@ -975,17 +966,21 @@ async def _direct_predict_impl(
        if base_margin is not None and isinstance(
            base_margin, (dd.Series, dd.DataFrame)
        ):
            # Easier for map_blocks
            base_margin_array: Optional[da.Array] = base_margin.to_dask_array()
        else:
            base_margin_array = base_margin
        # Input data is 2-dim array, output can be 1(reg, binary)/2(multi-class,
        # contrib)/3(contrib)/4(interaction) dims.
        # contrib)/3(contrib, interaction)/4(interaction) dims.
        if len(output_shape) == 1:
            drop_axis: Union[int, List[int]] = [1]  # drop from 2 to 1 dim.
            new_axis: Union[int, List[int]] = []
        else:
            drop_axis = []
            new_axis = [i + 2 for i in range(len(output_shape) - 2)]
            if isinstance(data, dd.DataFrame):
                new_axis = list(range(len(output_shape) - 2))
            else:
                new_axis = [i + 2 for i in range(len(output_shape) - 2)]
        predictions = da.map_blocks(
            mapped_predict,
            booster,
@@ -1001,28 +996,21 @@ async def _direct_predict_impl(


def _infer_predict_output(
    booster: Booster,
    data: Union[DaskDMatrix, _DaskCollection],
    inplace: bool,
    **kwargs: Any
    booster: Booster, features: int, is_df: bool, inplace: bool, **kwargs: Any
) -> Tuple[Tuple[int, ...], Dict[int, str]]:
    """Create a dummy test sample to infer output shape for prediction."""
    if isinstance(data, DaskDMatrix):
        features = data.num_col()
    else:
        features = data.shape[1]
    assert isinstance(features, int)
    rng = numpy.random.RandomState(1994)
    test_sample = rng.randn(1, features)
    if inplace:
        # clear the state to avoid gpu_id, gpu_predictor
        booster = Booster(model_file=booster.save_raw())
        test_predt = booster.inplace_predict(test_sample, **kwargs)
    else:
        m = DMatrix(test_sample)
        test_predt = booster.predict(m, **kwargs)
        kwargs = kwargs.copy()
        if kwargs.pop("predict_type") == "margin":
            kwargs["output_margin"] = True
        m = DMatrix(test_sample)
        test_predt = booster.predict(m, validate_features=False, **kwargs)
    n_columns = test_predt.shape[1] if len(test_predt.shape) > 1 else 1
    meta: Dict[int, str] = {}
    if _can_output_df(data, test_predt.shape):
    if _can_output_df(is_df, test_predt.shape):
        for i in range(n_columns):
            meta[i] = "f4"
    return test_predt.shape, meta
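To see what the dummy-sample trick produces: for a 3-class softprob model, the single-row test prediction has shape `(1, 3)`, so the dask metadata gets three `float32` columns. A hedged walk-through using the names from this diff:

```python
# Hypothetical trace of _infer_predict_output for a 3-class softprob model.
# booster.predict on the one-row dummy sample returns shape (1, 3), so:
test_shape = (1, 3)
n_columns = test_shape[1]
meta = {i: "f4" for i in range(n_columns)}   # {0: 'f4', 1: 'f4', 2: 'f4'}
# dask then uses `meta` to build DataFrame partitions with float32 columns
# and `test_shape[1:]` to size the output blocks passed to da.map_blocks.
```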
@@ -1034,7 +1022,7 @@ async def _get_model_future(
    if isinstance(model, Booster):
        booster = await client.scatter(model, broadcast=True)
    elif isinstance(model, dict):
        booster = await client.scatter(model["booster"])
        booster = await client.scatter(model["booster"], broadcast=True)
    elif isinstance(model, distributed.Future):
        booster = model
        if booster.type is not Booster:
@@ -1059,6 +1047,8 @@ async def _predict_async(
    approx_contribs: bool,
    pred_interactions: bool,
    validate_features: bool,
    iteration_range: Tuple[int, int],
    strict_shape: bool,
) -> _DaskCollection:
    _booster = await _get_model_future(client, model)
    if not isinstance(data, (DaskDMatrix, da.Array, dd.DataFrame)):
@@ -1077,43 +1067,51 @@ async def _predict_async(
                approx_contribs=approx_contribs,
                pred_interactions=pred_interactions,
                validate_features=validate_features,
                iteration_range=iteration_range,
                strict_shape=strict_shape,
            )
            if is_df and len(predt.shape) <= 2:
            if _can_output_df(is_df, predt.shape):
                if lazy_isinstance(partition, "cudf", "core.dataframe.DataFrame"):
                    import cudf

                    predt = cudf.DataFrame(predt, columns=columns)
                    predt = cudf.DataFrame(predt, columns=columns, dtype=numpy.float32)
                else:
                    predt = DataFrame(predt, columns=columns)
                    predt = DataFrame(predt, columns=columns, dtype=numpy.float32)
            return predt

    # Predict on dask collection directly.
    if isinstance(data, (da.Array, dd.DataFrame)):
        _output_shape, meta = _infer_predict_output(
            await _booster.result(),
            data,
        _output_shape, meta = await client.compute(
            client.submit(
                _infer_predict_output,
                _booster,
                features=data.shape[1],
                is_df=isinstance(data, dd.DataFrame),
                inplace=False,
                output_margin=output_margin,
                pred_leaf=pred_leaf,
                pred_contribs=pred_contribs,
                approx_contribs=approx_contribs,
                pred_interactions=pred_interactions,
            )
        )
        return await _direct_predict_impl(
            mapped_predict, _booster, data, None, _output_shape, meta
        )

    output_shape, _ = await client.compute(
        client.submit(
            _infer_predict_output,
            booster=_booster,
            features=data.num_col(),
            is_df=False,
            inplace=False,
            output_margin=output_margin,
            pred_leaf=pred_leaf,
            pred_contribs=pred_contribs,
            approx_contribs=approx_contribs,
            pred_interactions=pred_interactions,
            validate_features=False,
        )
    )
    return await _direct_predict_impl(
        mapped_predict, _booster, data, None, _output_shape, meta
    )

    output_shape, _ = _infer_predict_output(
        booster=await _booster.result(),
        data=data,
        inplace=False,
        output_margin=output_margin,
        pred_leaf=pred_leaf,
        pred_contribs=pred_contribs,
        approx_contribs=approx_contribs,
        pred_interactions=pred_interactions,
        validate_features=False,
    )
    # Prediction on dask DMatrix.
    partition_order = data.partition_order
@@ -1180,7 +1178,7 @@ async def _predict_async(
                futures[i], shape=(rows,) + output_shape[1:], dtype=numpy.float32
            )
        )
    predictions = await da.concatenate(arrays, axis=0)
    predictions = da.concatenate(arrays, axis=0)
    return predictions


@@ -1194,15 +1192,19 @@ def predict(  # pylint: disable=unused-argument
    pred_contribs: bool = False,
    approx_contribs: bool = False,
    pred_interactions: bool = False,
    validate_features: bool = True
    validate_features: bool = True,
    iteration_range: Tuple[int, int] = (0, 0),
    strict_shape: bool = False,
) -> Any:
    '''Run prediction with a trained booster.

    .. note::

        Using ``inplace_predict`` might be faster when meta information like
        ``base_margin`` is not needed.  For other parameters, please see
        ``Booster.predict``.
        Using ``inplace_predict`` might be faster when some features are not needed.  See
        :py:meth:`xgboost.Booster.predict` for details on various parameters.  When using
        ``pred_interactions`` with multi-class model, input should be ``da.Array`` or
        ``DaskDMatrix`` due to limitation in ``da.map_blocks``.


    .. versionadded:: 1.0.0

@@ -1232,69 +1234,83 @@ def predict(  # pylint: disable=unused-argument
    '''
    _assert_dask_support()
    client = _xgb_get_client(client)
    return client.sync(
        _predict_async, global_config=config.get_config(), **locals()
    )
    return client.sync(_predict_async, global_config=config.get_config(), **locals())


async def _inplace_predict_async(
async def _inplace_predict_async(  # pylint: disable=too-many-branches
    client: "distributed.Client",
    global_config: Dict[str, Any],
    model: Union[Booster, Dict, "distributed.Future"],
    data: _DaskCollection,
    iteration_range: Tuple[int, int] = (0, 0),
    predict_type: str = 'value',
    missing: float = numpy.nan
    iteration_range: Tuple[int, int],
    predict_type: str,
    missing: float,
    validate_features: bool,
    base_margin: Optional[_DaskCollection],
    strict_shape: bool,
) -> _DaskCollection:
    client = _xgb_get_client(client)
    booster = await _get_model_future(client, model)
    if not isinstance(data, (da.Array, dd.DataFrame)):
        raise TypeError(_expect([da.Array, dd.DataFrame], type(data)))
    if base_margin is not None and not isinstance(
        data, (da.Array, dd.DataFrame, dd.Series)
    ):
        raise TypeError(_expect([da.Array, dd.DataFrame, dd.Series], type(base_margin)))

    def mapped_predict(
        booster: Booster, data: Any, is_df: bool, columns: List[int], _: Any
        booster: Booster, data: Any, is_df: bool, columns: List[int], base_margin: Any
    ) -> Any:
        with config.config_context(**global_config):
            prediction = booster.inplace_predict(
                data,
                iteration_range=iteration_range,
                predict_type=predict_type,
                missing=missing
                missing=missing,
                base_margin=base_margin,
                validate_features=validate_features,
                strict_shape=strict_shape,
            )
        if is_df and len(prediction.shape) <= 2:
            if lazy_isinstance(data, 'cudf.core.dataframe', 'DataFrame'):
        if _can_output_df(is_df, prediction.shape):
            if lazy_isinstance(data, "cudf.core.dataframe", "DataFrame"):
                import cudf

                prediction = cudf.DataFrame(
                    prediction, columns=columns, dtype=numpy.float32
                )
            else:
                # If it's from pandas, the partition is a numpy array
                prediction = DataFrame(
                    prediction, columns=columns, dtype=numpy.float32
                )
                # If it's from pandas, the partition is a numpy array
                prediction = DataFrame(prediction, columns=columns, dtype=numpy.float32)
        return prediction

    shape, meta = _infer_predict_output(
        await booster.result(),
        data,
        True,
        predict_type=predict_type,
        iteration_range=iteration_range
    # await turns future into value.
    shape, meta = await client.compute(
        client.submit(
            _infer_predict_output,
            booster,
            features=data.shape[1],
            is_df=isinstance(data, dd.DataFrame),
            inplace=True,
            predict_type=predict_type,
            iteration_range=iteration_range,
        )
    )
    return await _direct_predict_impl(
        mapped_predict, booster, data, None, shape, meta
        mapped_predict, booster, data, base_margin, shape, meta
    )


def inplace_predict(  # pylint: disable=unused-argument
def inplace_predict(  # pylint: disable=unused-argument
    client: "distributed.Client",
    model: Union[TrainReturnT, Booster, "distributed.Future"],
    data: _DaskCollection,
    iteration_range: Tuple[int, int] = (0, 0),
    predict_type: str = 'value',
    missing: float = numpy.nan
    predict_type: str = "value",
    missing: float = numpy.nan,
    validate_features: bool = True,
    base_margin: Optional[_DaskCollection] = None,
    strict_shape: bool = False,
) -> Any:
    '''Inplace prediction.
    """Inplace prediction. See doc in :py:meth:`xgboost.Booster.inplace_predict` for details.

    .. versionadded:: 1.1.0

@@ -1304,16 +1320,27 @@ def inplace_predict(  # pylint: disable=unused-argument
        Specify the dask client used for training.  Use default client
        returned from dask if it's set to None.
    model:
        The trained model.  It can be a distributed.Future so user can
        pre-scatter it onto all workers.
        See :py:func:`xgboost.dask.predict` for details.
    data :
        dask collection.
    iteration_range:
        Specify the range of trees used for prediction.
        See :py:meth:`xgboost.Booster.predict` for details.
    predict_type:
        * 'value': Normal prediction result.
        * 'margin': Output the raw untransformed margin value.
        See :py:meth:`xgboost.Booster.inplace_predict` for details.
    missing:
        Value in the input data which needs to be present as a missing
        value.  If None, defaults to np.nan.
    base_margin:
        See :py:obj:`xgboost.DMatrix` for details.  Right now classifier is not well
        supported with base_margin as it requires the size of base margin to be `n_classes
        * n_samples`.

        .. versionadded:: 1.4.0

    strict_shape:
        See :py:meth:`xgboost.Booster.predict` for details.

        .. versionadded:: 1.4.0

    Returns
    -------
@@ -1322,7 +1349,7 @@ def inplace_predict(  # pylint: disable=unused-argument
        data is ``dask.dataframe.DataFrame``, return value can be
        ``dask.dataframe.Series``, ``dask.dataframe.DataFrame`` or ``dask.array.Array``,
        depending on the output shape.
    '''
    """
    _assert_dask_support()
    client = _xgb_get_client(client)
    return client.sync(
@@ -1334,9 +1361,11 @@ async def _async_wrap_evaluation_matrices(
    client: "distributed.Client", **kwargs: Any
) -> Tuple[DaskDMatrix, Optional[List[Tuple[DaskDMatrix, str]]]]:
    """A switch function for async environment."""

    def _inner(**kwargs: Any) -> DaskDMatrix:
        m = DaskDMatrix(client=client, **kwargs)
        return m

    train_dmatrix, evals = _wrap_evaluation_matrices(create_dmatrix=_inner, **kwargs)
    train_dmatrix = await train_dmatrix
    if evals is None:
@@ -1351,25 +1380,45 @@ async def _async_wrap_evaluation_matrices(


class DaskScikitLearnBase(XGBModel):
    '''Base class for implementing scikit-learn interface with Dask'''
    """Base class for implementing scikit-learn interface with Dask"""

    _client = None

    async def _predict_async(
        self, data: _DaskCollection,
        output_margin: bool = False,
        validate_features: bool = True,
        base_margin: Optional[_DaskCollection] = None
        self,
        data: _DaskCollection,
        output_margin: bool,
        validate_features: bool,
        base_margin: Optional[_DaskCollection],
        iteration_range: Optional[Tuple[int, int]],
    ) -> Any:
        test_dmatrix = await DaskDMatrix(
            client=self.client, data=data, base_margin=base_margin,
            missing=self.missing
        )
        pred_probs = await predict(client=self.client,
                                   model=self.get_booster(), data=test_dmatrix,
                                   output_margin=output_margin,
                                   validate_features=validate_features)
        return pred_probs
        iteration_range = self._get_iteration_range(iteration_range)
        if self._can_use_inplace_predict():
            predts = await inplace_predict(
                client=self.client,
                model=self.get_booster(),
                data=data,
                iteration_range=iteration_range,
                predict_type="margin" if output_margin else "value",
                missing=self.missing,
                base_margin=base_margin,
                validate_features=validate_features,
            )
            if isinstance(predts, dd.DataFrame):
                predts = predts.to_dask_array()
        else:
            test_dmatrix = await DaskDMatrix(
                self.client, data=data, base_margin=base_margin, missing=self.missing
            )
            predts = await predict(
                self.client,
                model=self.get_booster(),
                data=test_dmatrix,
                output_margin=output_margin,
                validate_features=validate_features,
                iteration_range=iteration_range,
            )
        return predts
||||
def predict(
|
||||
self,
|
||||
@ -1377,26 +1426,56 @@ class DaskScikitLearnBase(XGBModel):
|
||||
output_margin: bool = False,
|
||||
ntree_limit: Optional[int] = None,
|
||||
validate_features: bool = True,
|
||||
base_margin: Optional[_DaskCollection] = None
|
||||
base_margin: Optional[_DaskCollection] = None,
|
||||
iteration_range: Optional[Tuple[int, int]] = None,
|
||||
) -> Any:
|
||||
_assert_dask_support()
|
||||
msg = '`ntree_limit` is not supported on dask, use model slicing instead.'
|
||||
msg = "`ntree_limit` is not supported on dask, use `iteration_range` instead."
|
||||
assert ntree_limit is None, msg
|
||||
return self.client.sync(
|
||||
self._predict_async,
|
||||
X,
|
||||
output_margin=output_margin,
|
||||
validate_features=validate_features,
|
||||
base_margin=base_margin
|
||||
base_margin=base_margin,
|
||||
iteration_range=iteration_range,
|
||||
)
|
||||
|
||||
async def _apply_async(
|
||||
self,
|
||||
X: _DaskCollection,
|
||||
iteration_range: Optional[Tuple[int, int]] = None,
|
||||
) -> Any:
|
||||
iteration_range = self._get_iteration_range(iteration_range)
|
||||
test_dmatrix = await DaskDMatrix(self.client, data=X, missing=self.missing)
|
||||
predts = await predict(
|
||||
self.client,
|
||||
model=self.get_booster(),
|
||||
data=test_dmatrix,
|
||||
pred_leaf=True,
|
||||
iteration_range=iteration_range,
|
||||
)
|
||||
return predts
|
||||
|
||||
def apply(
|
||||
self,
|
||||
X: _DaskCollection,
|
||||
ntree_limit: Optional[int] = None,
|
||||
iteration_range: Optional[Tuple[int, int]] = None,
|
||||
) -> Any:
|
||||
_assert_dask_support()
|
||||
msg = "`ntree_limit` is not supported on dask, use `iteration_range` instead."
|
||||
assert ntree_limit is None, msg
|
||||
return self.client.sync(self._apply_async, X, iteration_range=iteration_range)
|
||||
|
||||
def __await__(self) -> Awaitable[Any]:
|
||||
# Generate a coroutine wrapper to make this class awaitable.
|
||||
async def _() -> Awaitable[Any]:
|
||||
return self
|
||||
|
||||
return self.client.sync(_).__await__()
|
||||
|
||||
def __getstate__(self):
|
||||
def __getstate__(self) -> Dict:
|
||||
this = self.__dict__.copy()
|
||||
if "_client" in this.keys():
|
||||
del this["_client"]
|
||||
@ -1404,7 +1483,7 @@ class DaskScikitLearnBase(XGBModel):
|
||||
|
||||
@property
|
||||
def client(self) -> "distributed.Client":
|
||||
'''The dask client used in this model.'''
|
||||
"""The dask client used in this model."""
|
||||
client = _xgb_get_client(self._client)
|
||||
return client
|
||||
|
||||
@ -1494,7 +1573,7 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
|
||||
sample_weight_eval_set: Optional[List[_DaskCollection]] = None,
|
||||
base_margin_eval_set: Optional[List[_DaskCollection]] = None,
|
||||
feature_weights: Optional[_DaskCollection] = None,
|
||||
callbacks: Optional[List[TrainingCallback]] = None
|
||||
callbacks: Optional[List[TrainingCallback]] = None,
|
||||
) -> "DaskXGBRegressor":
|
||||
_assert_dask_support()
|
||||
args = {k: v for k, v in locals().items() if k != "self"}
|
||||
@ -1556,9 +1635,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
|
||||
else:
|
||||
obj = None
|
||||
model, metric, params = self._configure_fit(
|
||||
booster=xgb_model,
|
||||
eval_metric=eval_metric,
|
||||
params=params
|
||||
booster=xgb_model, eval_metric=eval_metric, params=params
|
||||
)
|
||||
results = await train(
|
||||
client=self.client,
|
||||
@ -1610,18 +1687,19 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
|
||||
X: _DaskCollection,
|
||||
validate_features: bool,
|
||||
output_margin: bool,
|
||||
base_margin: Optional[_DaskCollection]
|
||||
base_margin: Optional[_DaskCollection],
|
||||
iteration_range: Optional[Tuple[int, int]],
|
||||
) -> _DaskCollection:
|
||||
test_dmatrix = await DaskDMatrix(
|
||||
client=self.client, data=X, base_margin=base_margin,
|
||||
missing=self.missing
|
||||
if iteration_range is None:
|
||||
iteration_range = (0, 0)
|
||||
predts = await super()._predict_async(
|
||||
data=X,
|
||||
output_margin=output_margin,
|
||||
validate_features=validate_features,
|
||||
base_margin=base_margin,
|
||||
iteration_range=iteration_range,
|
||||
)
|
||||
pred_probs = await predict(client=self.client,
|
||||
model=self.get_booster(),
|
||||
data=test_dmatrix,
|
||||
validate_features=validate_features,
|
||||
output_margin=output_margin)
|
||||
return _cls_predict_proba(self.objective, pred_probs, da.vstack)
|
||||
return _cls_predict_proba(self.objective, predts, da.vstack)
|
||||
|
||||
# pylint: disable=missing-function-docstring
|
||||
def predict_proba(
|
||||
@ -1630,37 +1708,49 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
        ntree_limit: Optional[int] = None,
        validate_features: bool = True,
        output_margin: bool = False,
        base_margin: Optional[_DaskCollection] = None
        base_margin: Optional[_DaskCollection] = None,
        iteration_range: Optional[Tuple[int, int]] = None,
    ) -> Any:
        _assert_dask_support()
        msg = '`ntree_limit` is not supported on dask, use model slicing instead.'
        msg = "`ntree_limit` is not supported on dask, use `iteration_range` instead."
        assert ntree_limit is None, msg
        return self.client.sync(
            self._predict_proba_async,
            X=X,
            validate_features=validate_features,
            output_margin=output_margin,
            base_margin=base_margin
            base_margin=base_margin,
            iteration_range=iteration_range,
        )

    predict_proba.__doc__ = XGBClassifier.predict_proba.__doc__

    async def _predict_async(
        self, data: _DaskCollection,
        output_margin: bool = False,
        validate_features: bool = True,
        base_margin: Optional[_DaskCollection] = None
        self,
        data: _DaskCollection,
        output_margin: bool,
        validate_features: bool,
        base_margin: Optional[_DaskCollection],
        iteration_range: Optional[Tuple[int, int]],
    ) -> _DaskCollection:
        pred_probs = await super()._predict_async(
            data, output_margin, validate_features, base_margin
            data, output_margin, validate_features, base_margin, iteration_range
        )
        if output_margin:
            return pred_probs

        if self.n_classes_ == 2:
            if len(pred_probs.shape) == 1:
                preds = (pred_probs > 0.5).astype(int)
            else:
                preds = da.argmax(pred_probs, axis=1)
            assert len(pred_probs.shape) == 2
            assert isinstance(pred_probs, da.Array)
            # when using da.argmax directly, dask will construct a numpy-based return
            # array, which runs into an error when computing GPU-based prediction.

            def _argmax(x: Any) -> Any:
                return x.argmax(axis=1)

            preds = da.map_blocks(_argmax, pred_probs, drop_axis=1)
        return preds
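
Note: the `map_blocks` trick above keeps the reduction on whatever array type backs each block (numpy or cupy), which is what makes it safe for GPU inputs. A minimal standalone sketch of the same pattern, with illustrative array and chunk sizes that are not part of the commit:

    import dask.array as da
    import numpy as np

    def _argmax(x):
        # x.argmax dispatches to the underlying block type, so the result
        # stays on the same device as the input block.
        return x.argmax(axis=1)

    probs = da.from_array(np.random.rand(8, 3), chunks=(4, 3))
    # drop_axis=1 tells dask that the class axis disappears from the output.
    labels = da.map_blocks(_argmax, probs, drop_axis=1)
    print(labels.compute())  # shape (8,)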

@ -1770,7 +1860,7 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
        callbacks: Optional[List[TrainingCallback]] = None
    ) -> "DaskXGBRanker":
        _assert_dask_support()
        args = {k: v for k, v in locals().items() if k != 'self'}
        args = {k: v for k, v in locals().items() if k != "self"}
        return self.client.sync(self._fit_async, **args)

    # FIXME(trivialfis): arguments differ due to additional parameters like group and qid.

@ -6,7 +6,8 @@ import warnings
import json
from typing import Union, Optional, List, Dict, Callable, Tuple, Any
import numpy as np
from .core import Booster, DMatrix, XGBoostError, _deprecate_positional_args
from .core import Booster, DMatrix, XGBoostError
from .core import _deprecate_positional_args, _convert_ntree_limit
from .core import Metric
from .training import train
from .data import _is_cudf_df, _is_cudf_ser, _is_cupy_array
@ -413,8 +414,8 @@ class XGBModel(XGBModelBase):
            # Simple optimization to gain speed (inspect is slow)
            return self

        # this concatenates kwargs into paraemters, enabling `get_params` for
        # obtaining parameters from keyword paraemters.
        # this concatenates kwargs into parameters, enabling `get_params` for
        # obtaining parameters from keyword parameters.
        for key, value in params.items():
            if hasattr(self, key):
                setattr(self, key, value)
@ -747,26 +748,45 @@ class XGBModel(XGBModelBase):
        self._set_evaluation_result(evals_result)
        return self

    def _can_use_inplace_predict(self) -> bool:
        # When predictor is explicitly set, using `inplace_predict` might result
        # in an error with incompatible data types.
        # Inplace predict doesn't handle as many data types as DMatrix, but it's
        # sufficient for the dask interface, where input is simpler.
        params = self.get_params()
        booster = self.booster
        if params.get("predictor", None) is None and (
            booster is None or booster == "gbtree"
        ):
            return True
        return False
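
Note: the effect of this check is that the dask estimators silently fall back to the DMatrix-based code path whenever inplace prediction is not known to be safe. A rough illustration of the dispatch rule as a hypothetical standalone function, mirroring the logic above rather than calling any private API:

    def can_use_inplace_predict(params: dict) -> bool:
        # Only the default predictor and the (default) gbtree booster are
        # known to handle the raw input types accepted by inplace predict.
        booster = params.get("booster")
        return params.get("predictor") is None and booster in (None, "gbtree")

    assert can_use_inplace_predict({})
    assert not can_use_inplace_predict({"predictor": "gpu_predictor"})
    assert not can_use_inplace_predict({"booster": "dart"})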

    def _get_iteration_range(
        self, iteration_range: Optional[Tuple[int, int]]
    ) -> Tuple[int, int]:
        if iteration_range is None or iteration_range[1] == 0:
            # Use best_iteration if defined.
            try:
                iteration_range = (0, self.best_iteration + 1)
            except AttributeError:
                iteration_range = (0, 0)
        if self.booster == "gblinear":
            iteration_range = (0, 0)
        return iteration_range
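
Note: in other words, when the caller does not pin an explicit range, prediction now stops at `best_iteration` from early stopping. A hedged sketch of the same rule, written as a standalone function with `best_iteration` and `booster` as plain arguments instead of estimator attributes:

    from typing import Optional, Tuple

    def resolve_iteration_range(
        iteration_range: Optional[Tuple[int, int]],
        best_iteration: Optional[int],
        booster: Optional[str],
    ) -> Tuple[int, int]:
        if iteration_range is None or iteration_range[1] == 0:
            # (0, 0) means "all trees"; best_iteration + 1 makes the half-open
            # range include the best round found by early stopping.
            iteration_range = (
                (0, best_iteration + 1) if best_iteration is not None else (0, 0)
            )
        if booster == "gblinear":
            iteration_range = (0, 0)  # gblinear has no tree layers to slice
        return iteration_range

    assert resolve_iteration_range(None, 31, "gbtree") == (0, 32)
    assert resolve_iteration_range((5, 10), None, "gbtree") == (5, 10)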

    def predict(
        self,
        X,
        output_margin=False,
        ntree_limit=None,
        validate_features=True,
        base_margin=None
        base_margin=None,
        iteration_range=None,
    ):
        """
        Predict with `X`.

        .. note:: This function is not thread safe.

          For each booster object, predict can only be called from one thread.
          If you want to run prediction using multiple threads, call ``xgb.copy()`` to make
          copies of the model object and then call ``predict()``.

          .. code-block:: python

            preds = bst.predict(dtest, ntree_limit=num_round)
        .. note:: This function is only thread safe for `gbtree`

        Parameters
        ----------
@ -775,37 +795,40 @@ class XGBModel(XGBModelBase):
        output_margin : bool
            Whether to output the raw untransformed margin value.
        ntree_limit : int
            Limit number of trees in the prediction; defaults to best_ntree_limit if
            defined (i.e. it has been trained with early stopping), otherwise 0 (use all
            trees).
            Deprecated, use `iteration_range` instead.
        validate_features : bool
            When this is True, validate that the Booster's and data's feature_names are identical.
            Otherwise, it is assumed that the feature_names are the same.
            When this is True, validate that the Booster's and data's feature_names are
            identical. Otherwise, it is assumed that the feature_names are the same.
        base_margin : array_like
            Margin added to prediction.
        iteration_range :
            Specifies which layer of trees are used in prediction. For example, if a
            random forest is trained with 100 rounds, specifying `iteration_range=(10,
            20)` means only the forests built during rounds [10, 20) (half-open set)
            are used in this prediction.

            .. versionadded:: 1.4.0
        Returns
        -------
        prediction : numpy array
        """
        # pylint: disable=missing-docstring,invalid-name
        test_dmatrix = DMatrix(X, base_margin=base_margin,
                               missing=self.missing, nthread=self.n_jobs)
        # get ntree_limit to use - if none specified, default to
        # best_ntree_limit if defined, otherwise 0.
        if ntree_limit is None:
            try:
                ntree_limit = self.best_ntree_limit
            except AttributeError:
                ntree_limit = 0
        iteration_range = _convert_ntree_limit(
            self.get_booster(), ntree_limit, iteration_range
        )
        iteration_range = self._get_iteration_range(iteration_range)
        test = DMatrix(
            X, base_margin=base_margin, missing=self.missing, nthread=self.n_jobs
        )
        return self.get_booster().predict(
            test_dmatrix,
            data=test,
            iteration_range=iteration_range,
            output_margin=output_margin,
            ntree_limit=ntree_limit,
            validate_features=validate_features
            validate_features=validate_features,
        )
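
Note: a short usage sketch of the new keyword on synthetic data; `n_estimators` and the range below are arbitrary:

    import numpy as np
    import xgboost as xgb

    X = np.random.rand(128, 4)
    y = np.random.rand(128)
    reg = xgb.XGBRegressor(n_estimators=100).fit(X, y)

    # Use only the trees built during rounds [10, 20), a half-open interval.
    partial = reg.predict(X, iteration_range=(10, 20))
    full = reg.predict(X)  # defaults to best_iteration when early stopping was used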

    def apply(self, X, ntree_limit=0):
    def apply(
        self, X, ntree_limit: int = 0, iteration_range: Optional[Tuple[int, int]] = None
    ) -> np.ndarray:
        """Return the predicted leaf of every tree for each sample.

        Parameters
@ -823,10 +846,16 @@ class XGBModel(XGBModelBase):
            leaf x ends up in. Leaves are numbered within
            ``[0; 2**(self.max_depth+1))``, possibly with gaps in the numbering.
        """
        iteration_range = _convert_ntree_limit(
            self.get_booster(), ntree_limit, iteration_range
        )
        iteration_range = self._get_iteration_range(iteration_range)
        test_dmatrix = DMatrix(X, missing=self.missing, nthread=self.n_jobs)
        return self.get_booster().predict(test_dmatrix,
                                          pred_leaf=True,
                                          ntree_limit=ntree_limit)
        return self.get_booster().predict(
            test_dmatrix,
            pred_leaf=True,
            iteration_range=iteration_range
        )

    def evals_result(self):
        """Return the evaluation results.
@ -945,8 +974,7 @@ class XGBModel(XGBModelBase):
                'Coefficients are not defined for Booster type {}'
                .format(self.booster))
        b = self.get_booster()
        coef = np.array(json.loads(
            b.get_dump(dump_format='json')[0])['weight'])
        coef = np.array(json.loads(b.get_dump(dump_format='json')[0])['weight'])
        # Logic for multiclass classification
        n_classes = getattr(self, 'n_classes_', None)
        if n_classes is not None:
@ -1157,14 +1185,16 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
        output_margin=False,
        ntree_limit=None,
        validate_features=True,
        base_margin=None
        base_margin=None,
        iteration_range: Optional[Tuple[int, int]] = None,
    ):
        class_probs = super().predict(
            X=X,
            output_margin=output_margin,
            ntree_limit=ntree_limit,
            validate_features=validate_features,
            base_margin=base_margin
            base_margin=base_margin,
            iteration_range=iteration_range,
        )
        if output_margin:
            # If output_margin is active, simply return the scores
@ -1180,29 +1210,34 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
            return self._le.inverse_transform(column_indexes)
        return column_indexes

    def predict_proba(self, X, ntree_limit=None, validate_features=False,
                      base_margin=None):
    def predict_proba(
        self,
        X,
        ntree_limit=None,
        validate_features=False,
        base_margin=None,
        iteration_range: Optional[Tuple[int, int]] = None,
    ) -> np.ndarray:
        """ Predict the probability of each `X` example being of a given class.

        .. note:: This function is not thread safe

            For each booster object, predict can only be called from one
            thread. If you want to run prediction using multiple threads, call
            ``xgb.copy()`` to make copies of the model object and then call predict
        .. note:: This function is only thread safe for `gbtree`

        Parameters
        ----------
        X : array_like
            Feature matrix.
        ntree_limit : int
            Limit number of trees in the prediction; defaults to best_ntree_limit if
            defined (i.e. it has been trained with early stopping), otherwise 0 (use all
            trees).
            Deprecated, use `iteration_range` instead.
        validate_features : bool
            When this is True, validate that the Booster's and data's feature_names are
            identical. Otherwise, it is assumed that the feature_names are the same.
        base_margin : array_like
            Margin added to prediction.
        iteration_range :
            Specifies which layer of trees are used in prediction. For example, if a
            random forest is trained with 100 rounds, specifying `iteration_range=(10,
            20)` means only the forests built during rounds [10, 20) (half-open set)
            are used in this prediction.

        Returns
        -------
@ -1215,7 +1250,8 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
            output_margin=False,
            ntree_limit=ntree_limit,
            validate_features=validate_features,
            base_margin=base_margin
            base_margin=base_margin,
            iteration_range=iteration_range
        )
        return _cls_predict_proba(self.objective, class_probs, np.vstack)

@ -4,9 +4,8 @@
"""Training Library containing training routines."""
import warnings
import copy
import json
import numpy as np
from .core import Booster, XGBoostError
from .core import Booster, XGBoostError, _get_booster_layer_trees
from .compat import (SKLEARN_INSTALLED, XGBStratifiedKFold)
from . import callback

@ -91,24 +90,7 @@ def _train_internal(params, dtrain,
    # These should be moved into callback functions `after_training`, but until old
    # callbacks are removed, the train function is the only place for setting the
    # attributes.
    config = json.loads(bst.save_config())
    booster = config['learner']['gradient_booster']['name']
    if booster == 'gblinear':
        num_parallel_tree = 0
    elif booster == 'dart':
        num_parallel_tree = int(
            config['learner']['gradient_booster']['gbtree']['gbtree_train_param'][
                'num_parallel_tree'
            ]
        )
    elif booster == 'gbtree':
        num_parallel_tree = int(
            config['learner']['gradient_booster']['gbtree_train_param'][
                'num_parallel_tree']
        )
    else:
        raise ValueError(f'Unknown booster: {booster}')

    num_parallel_tree, _ = _get_booster_layer_trees(bst)
    if bst.attr('best_score') is not None:
        bst.best_score = float(bst.attr('best_score'))
        bst.best_iteration = int(bst.attr('best_iteration'))

@ -619,20 +619,58 @@ XGB_DLL int XGBoosterPredict(BoosterHandle handle,
  CHECK_HANDLE();
  auto *learner = static_cast<Learner*>(handle);
  auto& entry = learner->GetThreadLocal().prediction_entry;
  learner->Predict(
      *static_cast<std::shared_ptr<DMatrix>*>(dmat),
      (option_mask & 1) != 0,
      &entry.predictions, ntree_limit,
      static_cast<bool>(training),
      (option_mask & 2) != 0,
      (option_mask & 4) != 0,
      (option_mask & 8) != 0,
      (option_mask & 16) != 0);
  auto iteration_end = GetIterationFromTreeLimit(ntree_limit, learner);
  learner->Predict(*static_cast<std::shared_ptr<DMatrix> *>(dmat),
                   (option_mask & 1) != 0, &entry.predictions, 0, iteration_end,
                   static_cast<bool>(training), (option_mask & 2) != 0,
                   (option_mask & 4) != 0, (option_mask & 8) != 0,
                   (option_mask & 16) != 0);
  *out_result = dmlc::BeginPtr(entry.predictions.ConstHostVector());
  *len = static_cast<xgboost::bst_ulong>(entry.predictions.Size());
  API_END();
}

XGB_DLL int XGBoosterPredictFromDMatrix(BoosterHandle handle,
                                        DMatrixHandle dmat,
                                        char const* c_json_config,
                                        xgboost::bst_ulong const **out_shape,
                                        xgboost::bst_ulong *out_dim,
                                        bst_float const **out_result) {
  API_BEGIN();
  if (handle == nullptr) {
    LOG(FATAL) << "Booster has not been initialized or has already been disposed.";
  }
  if (dmat == nullptr) {
    LOG(FATAL) << "DMatrix has not been initialized or has already been disposed.";
  }
  auto config = Json::Load(StringView{c_json_config});

  auto *learner = static_cast<Learner*>(handle);
  auto& entry = learner->GetThreadLocal().prediction_entry;
  auto p_m = *static_cast<std::shared_ptr<DMatrix> *>(dmat);
  auto type = PredictionType(get<Integer const>(config["type"]));
  auto iteration_begin = get<Integer const>(config["iteration_begin"]);
  auto iteration_end = get<Integer const>(config["iteration_end"]);
  learner->Predict(
      *static_cast<std::shared_ptr<DMatrix> *>(dmat),
      type == PredictionType::kMargin, &entry.predictions, iteration_begin,
      iteration_end, get<Boolean const>(config["training"]),
      type == PredictionType::kLeaf, type == PredictionType::kContribution,
      type == PredictionType::kApproxContribution,
      type == PredictionType::kInteraction);
  *out_result = dmlc::BeginPtr(entry.predictions.ConstHostVector());
  auto &shape = learner->GetThreadLocal().prediction_shape;
  auto chunksize = p_m->Info().num_row_ == 0 ? 0 : entry.predictions.Size() / p_m->Info().num_row_;
  auto rounds = iteration_end - iteration_begin;
  rounds = rounds == 0 ? learner->BoostedRounds() : rounds;
  // Determine shape
  bool strict_shape = get<Boolean const>(config["strict_shape"]);
  CalcPredictShape(strict_shape, type, p_m->Info().num_row_,
                   p_m->Info().num_col_, chunksize, learner->Groups(), rounds,
                   &shape, out_dim);
  *out_shape = dmlc::BeginPtr(shape);
  API_END();
}
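
Note: the Python wrapper drives this entry point through ctypes. A stripped-down sketch of the call sequence, with error handling omitted; `_LIB` and `_check_call` are internal helpers from `xgboost.core`, not public API, and the JSON keys are taken from the documentation above:

    import ctypes
    import json
    import numpy as np
    from xgboost.core import _LIB, _check_call  # internal, for illustration only

    def predict_from_dmatrix(booster, dmatrix):
        config = json.dumps({
            "type": 0,             # normal prediction
            "training": False,
            "iteration_begin": 0,
            "iteration_end": 0,    # 0 means all boosted rounds
            "strict_shape": False,
        }).encode("utf-8")
        shape = ctypes.POINTER(ctypes.c_uint64)()
        dims = ctypes.c_uint64()
        preds = ctypes.POINTER(ctypes.c_float)()
        _check_call(_LIB.XGBoosterPredictFromDMatrix(
            booster.handle, dmatrix.handle, config,
            ctypes.byref(shape), ctypes.byref(dims), ctypes.byref(preds)))
        out_shape = tuple(shape[i] for i in range(dims.value))
        # The output buffers are owned by the library; copy before the next call.
        return np.ctypeslib.as_array(preds, shape=out_shape).copy()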

template <typename T>
void InplacePredictImpl(std::shared_ptr<T> x, std::shared_ptr<DMatrix> p_m,
@ -705,7 +743,7 @@ XGB_DLL int XGBoosterPredictFromCSR(BoosterHandle handle, char const *indptr,
}

#if !defined(XGBOOST_USE_CUDA)
XGB_DLL int XGBoosterPredictFromArrayInterface(
XGB_DLL int XGBoosterPredictFromCUDAArray(
    BoosterHandle handle, char const *c_json_strs, char const *c_json_config,
    DMatrixHandle m, xgboost::bst_ulong const **out_shape, xgboost::bst_ulong *out_dim,
    const float **out_result) {
@ -715,7 +753,7 @@ XGB_DLL int XGBoosterPredictFromArrayInterface(
  API_END();
}

XGB_DLL int XGBoosterPredictFromArrayInterfaceColumns(
XGB_DLL int XGBoosterPredictFromCUDAColumnar(
    BoosterHandle handle, char const *c_json_strs, char const *c_json_config,
    DMatrixHandle m, xgboost::bst_ulong const **out_shape, xgboost::bst_ulong *out_dim,
    const float **out_result) {

@ -66,8 +66,7 @@ int InplacePreidctCuda(BoosterHandle handle, char const *c_json_strs,
  API_END();
}

// A hidden API as cache id is not being supported yet.
XGB_DLL int XGBoosterPredictFromArrayInterfaceColumns(
XGB_DLL int XGBoosterPredictFromCudaColumnar(
    BoosterHandle handle, char const *c_json_strs, char const *c_json_config,
    DMatrixHandle m, xgboost::bst_ulong const **out_shape,
    xgboost::bst_ulong *out_dim, const float **out_result) {
@ -79,8 +78,7 @@ XGB_DLL int XGBoosterPredictFromArrayInterfaceColumns(
    handle, c_json_strs, c_json_config, p_m, out_shape, out_dim, out_result);
}

// A hidden API as cache id is not being supported yet.
XGB_DLL int XGBoosterPredictFromArrayInterface(
XGB_DLL int XGBoosterPredictFromCudaArray(
    BoosterHandle handle, char const *c_json_strs, char const *c_json_config,
    DMatrixHandle m, xgboost::bst_ulong const **out_shape,
    xgboost::bst_ulong *out_dim, const float **out_result) {

@ -9,6 +9,7 @@
#include <vector>

#include "xgboost/logging.h"
#include "xgboost/json.h"
#include "xgboost/learner.h"

namespace xgboost {
@ -30,8 +31,8 @@ inline void CalcPredictShape(bool strict_shape, PredictionType type, size_t rows
                             std::vector<bst_ulong> *out_shape,
                             xgboost::bst_ulong *out_dim) {
  auto &shape = *out_shape;
  if ((type == PredictionType::kMargin || type == PredictionType::kValue) &&
      rows != 0) {
  if (type == PredictionType::kMargin && rows != 0) {
    // When kValue is used, softmax can change the chunksize.
    CHECK_EQ(chunksize, groups);
  }

@ -110,5 +111,35 @@ inline void CalcPredictShape(bool strict_shape, PredictionType type, size_t rows
      std::accumulate(shape.cbegin(), shape.cend(), 1, std::multiplies<>{}),
      chunksize * rows);
}

// Reverse the ntree_limit in the old prediction API.
inline uint32_t GetIterationFromTreeLimit(uint32_t ntree_limit, Learner *learner) {
  // On Python and R, `best_ntree_limit` is set to `best_iteration * num_parallel_tree`.
  // To reverse it we just divide it by `num_parallel_tree`.
  if (ntree_limit != 0) {
    learner->Configure();
    uint32_t num_parallel_tree = 0;

    Json config{Object()};
    learner->SaveConfig(&config);
    auto const &booster =
        get<String const>(config["learner"]["gradient_booster"]["name"]);
    if (booster == "gblinear") {
      num_parallel_tree = 0;
    } else if (booster == "dart") {
      num_parallel_tree = std::stoi(
          get<String const>(config["learner"]["gradient_booster"]["gbtree"]
                                  ["gbtree_train_param"]["num_parallel_tree"]));
    } else if (booster == "gbtree") {
      num_parallel_tree = std::stoi(get<String const>(
          (config["learner"]["gradient_booster"]["gbtree_train_param"]
                 ["num_parallel_tree"])));
    } else {
      LOG(FATAL) << "Unknown booster: " << booster;
    }
    ntree_limit /= std::max(num_parallel_tree, 1u);
  }
  return ntree_limit;
}
}  // namespace xgboost
#endif  // XGBOOST_C_API_C_API_UTILS_H_
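
Note: the conversion is plain integer arithmetic; the only subtlety is that one boosting "layer" of a forest holds `num_parallel_tree` trees per class. A hedged Python rendering of the same rule, standalone and not part of the library:

    def iteration_from_tree_limit(ntree_limit: int, num_parallel_tree: int) -> int:
        # best_ntree_limit was historically best_iteration * num_parallel_tree,
        # so dividing recovers the end iteration; 0 keeps meaning "all trees".
        if ntree_limit == 0:
            return 0
        return ntree_limit // max(num_parallel_tree, 1)

    assert iteration_from_tree_limit(0, 3) == 0    # all rounds
    assert iteration_from_tree_limit(12, 3) == 4   # 4 rounds of a 3-tree forest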

@ -25,6 +25,7 @@
#include "common/config.h"
#include "common/io.h"
#include "common/version.h"
#include "c_api/c_api_utils.h"

namespace xgboost {
enum CLITask {
@ -58,6 +59,8 @@ struct CLIParam : public XGBoostParameter<CLIParam> {
  int dsplit;
  /*!\brief limit number of trees in prediction */
  int ntree_limit;
  int iteration_begin;
  int iteration_end;
  /*!\brief whether to directly output margin value */
  bool pred_margin;
  /*! \brief whether dump statistics along with model */
@ -109,7 +112,11 @@ struct CLIParam : public XGBoostParameter<CLIParam> {
        .add_enum("row", 2)
        .describe("Data split mode.");
    DMLC_DECLARE_FIELD(ntree_limit).set_default(0).set_lower_bound(0)
        .describe("Number of trees used for prediction, 0 means use all trees.");
        .describe("(Deprecated) Use iteration_begin/iteration_end instead.");
    DMLC_DECLARE_FIELD(iteration_begin).set_default(0).set_lower_bound(0)
        .describe("Beginning of boosted tree iteration used for prediction.");
    DMLC_DECLARE_FIELD(iteration_end).set_default(0).set_lower_bound(0)
        .describe("End of boosted tree iteration used for prediction. 0 means all the trees.");
    DMLC_DECLARE_FIELD(pred_margin).set_default(false)
        .describe("Whether to predict margin value instead of probability.");
    DMLC_DECLARE_FIELD(dump_stats).set_default(false)
@ -334,7 +341,13 @@ class CLI {

    LOG(INFO) << "Start prediction...";
    HostDeviceVector<bst_float> preds;
    learner_->Predict(dtest, param_.pred_margin, &preds, param_.ntree_limit);
    if (param_.ntree_limit != 0) {
      param_.iteration_end = GetIterationFromTreeLimit(param_.ntree_limit, learner_.get());
      LOG(WARNING) << "`ntree_limit` is deprecated, use `iteration_begin` and "
                      "`iteration_end` instead.";
    }
    learner_->Predict(dtest, param_.pred_margin, &preds, param_.iteration_begin,
                      param_.iteration_end);
    LOG(CONSOLE) << "Writing prediction to " << param_.name_pred;

    std::unique_ptr<dmlc::Stream> fo(
@ -47,6 +47,12 @@ struct GBLinearTrainParam : public XGBoostParameter<GBLinearTrainParam> {
        .describe("Maximum rows per batch.");
  }
};

void LinearCheckLayer(unsigned layer_begin, unsigned layer_end) {
  CHECK_EQ(layer_begin, 0) << "Linear booster does not support prediction range.";
  CHECK_EQ(layer_end, 0) << "Linear booster does not support prediction range.";
}

/*!
 * \brief gradient boosted linear model
 */
@ -130,20 +136,19 @@ class GBLinear : public GradientBooster {
    monitor_.Stop("DoBoost");
  }

  void PredictBatch(DMatrix *p_fmat,
                    PredictionCacheEntry *predts,
                    bool, unsigned ntree_limit) override {
  void PredictBatch(DMatrix *p_fmat, PredictionCacheEntry *predts,
                    bool training, unsigned layer_begin, unsigned layer_end) override {
    monitor_.Start("PredictBatch");
    LinearCheckLayer(layer_begin, layer_end);
    auto* out_preds = &predts->predictions;
    CHECK_EQ(ntree_limit, 0U)
        << "GBLinear::Predict ntrees is only valid for gbtree predictor";
    this->PredictBatchInternal(p_fmat, &out_preds->HostVector());
    monitor_.Stop("PredictBatch");
  }
  // add base margin
  void PredictInstance(const SparsePage::Inst &inst,
                       std::vector<bst_float> *out_preds,
                       unsigned) override {
                       unsigned layer_begin, unsigned layer_end) override {
    LinearCheckLayer(layer_begin, layer_end);
    const int ngroup = model_.learner_model_param->num_output_group;
    for (int gid = 0; gid < ngroup; ++gid) {
      this->Pred(inst, dmlc::BeginPtr(*out_preds), gid,
@ -151,16 +156,15 @@ class GBLinear : public GradientBooster {
    }
  }

  void PredictLeaf(DMatrix *, HostDeviceVector<bst_float> *, unsigned) override {
  void PredictLeaf(DMatrix *, HostDeviceVector<bst_float> *, unsigned, unsigned) override {
    LOG(FATAL) << "gblinear does not support prediction of leaf index";
  }

  void PredictContribution(DMatrix* p_fmat,
                           HostDeviceVector<bst_float>* out_contribs,
                           unsigned ntree_limit, bool, int, unsigned) override {
                           unsigned layer_begin, unsigned layer_end, bool, int, unsigned) override {
    model_.LazyInitModel();
    CHECK_EQ(ntree_limit, 0U)
        << "GBLinear::PredictContribution: ntrees is only valid for gbtree predictor";
    LinearCheckLayer(layer_begin, layer_end);
    const auto& base_margin = p_fmat->Info().base_margin_.ConstHostVector();
    const int ngroup = model_.learner_model_param->num_output_group;
    const size_t ncolumns = model_.learner_model_param->num_feature + 1;
@ -197,7 +201,8 @@ class GBLinear : public GradientBooster {

  void PredictInteractionContributions(DMatrix* p_fmat,
                                       HostDeviceVector<bst_float>* out_contribs,
                                       unsigned, bool) override {
                                       unsigned layer_begin, unsigned layer_end, bool) override {
    LinearCheckLayer(layer_begin, layer_end);
    std::vector<bst_float>& contribs = out_contribs->HostVector();

    // linear models have no interaction effects

@ -414,7 +414,7 @@ void GBTree::Slice(int32_t layer_begin, int32_t layer_end, int32_t step,
  auto layer_trees = this->LayerTrees();

  layer_end = layer_end == 0 ? model_.trees.size() / layer_trees : layer_end;
  CHECK_GE(layer_end, layer_begin);
  CHECK_GT(layer_end, layer_begin);
  CHECK_GE(step, 1);
  int32_t n_layers = (layer_end - layer_begin) / step;
  std::vector<std::unique_ptr<RegTree>> &out_trees = out_model.trees;
@ -438,10 +438,35 @@ void GBTree::Slice(int32_t layer_begin, int32_t layer_end, int32_t step,
void GBTree::PredictBatch(DMatrix* p_fmat,
                          PredictionCacheEntry* out_preds,
                          bool,
                          unsigned ntree_limit) {
                          unsigned layer_begin,
                          unsigned layer_end) {
  CHECK(configured_);
  if (layer_end == 0) {
    layer_end = this->BoostedRounds();
  }
  if (layer_begin != 0 || layer_end < out_preds->version) {
    // cache is dropped.
    out_preds->version = 0;
  }
  bool reset = false;
  if (layer_begin == 0) {
    layer_begin = out_preds->version;
  } else {
    // When begin layer is not 0, the cache is not useful.
    reset = true;
  }

  uint32_t tree_begin, tree_end;
  std::tie(tree_begin, tree_end) =
      detail::LayerToTree(model_, tparam_, layer_begin, layer_end);
  GetPredictor(&out_preds->predictions, p_fmat)
      ->PredictBatch(p_fmat, out_preds, model_, 0, ntree_limit);
      ->PredictBatch(p_fmat, out_preds, model_, tree_begin, tree_end);
  if (reset) {
    out_preds->version = 0;
  } else {
    uint32_t delta = layer_end - out_preds->version;
    out_preds->Update(delta);
  }
}
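
Note: the cache bookkeeping above reduces to a small state machine: the cached vector is valid for layers [0, version), it is reused only when the request starts at layer 0, and `version` advances to the requested end layer afterwards. A hypothetical Python model of that rule, for illustration only:

    def predict_with_cache(version: int, layer_begin: int, layer_end: int, n_rounds: int):
        """Return (compute_range, new_version) for one cached prediction request."""
        if layer_end == 0:
            layer_end = n_rounds
        if layer_begin != 0 or layer_end < version:
            version = 0  # cache is dropped
        if layer_begin == 0:
            # only the layers missing from the cache are computed
            return (version, layer_end), layer_end
        # non-zero begin: compute the full range and keep no cache
        return (layer_begin, layer_end), 0

    assert predict_with_cache(2, 0, 5, 5) == ((2, 5), 5)  # incremental update
    assert predict_with_cache(2, 1, 3, 5) == ((1, 3), 0)  # partial range, cache reset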

std::unique_ptr<Predictor> const &
@ -603,13 +628,14 @@ class Dart : public GBTree {
  void PredictBatch(DMatrix* p_fmat,
                    PredictionCacheEntry* p_out_preds,
                    bool training,
                    unsigned ntree_limit) override {
                    unsigned layer_begin,
                    unsigned layer_end) override {
    DropTrees(training);
    int num_group = model_.learner_model_param->num_output_group;
    ntree_limit *= num_group;
    if (ntree_limit == 0 || ntree_limit > model_.trees.size()) {
      ntree_limit = static_cast<unsigned>(model_.trees.size());
    }
    uint32_t tree_begin, tree_end;
    std::tie(tree_begin, tree_end) =
        detail::LayerToTree(model_, tparam_, layer_begin, layer_end);

    size_t n = num_group * p_fmat->Info().num_row_;
    const auto &base_margin = p_fmat->Info().base_margin_.ConstHostVector();
    auto& out_preds = p_out_preds->predictions.HostVector();
@ -623,26 +649,24 @@ class Dart : public GBTree {
    }
    const int nthread = omp_get_max_threads();
    InitThreadTemp(nthread);
    PredLoopSpecalize(p_fmat, &out_preds, num_group, 0, ntree_limit);
    PredLoopSpecalize(p_fmat, &out_preds, num_group, tree_begin, tree_end);
  }

  void PredictInstance(const SparsePage::Inst &inst,
                       std::vector<bst_float> *out_preds,
                       unsigned ntree_limit) override {
                       unsigned layer_begin, unsigned layer_end) override {
    DropTrees(false);
    if (thread_temp_.size() == 0) {
      thread_temp_.resize(1, RegTree::FVec());
      thread_temp_[0].Init(model_.learner_model_param->num_feature);
    }
    out_preds->resize(model_.learner_model_param->num_output_group);
    ntree_limit *= model_.learner_model_param->num_output_group;
    if (ntree_limit == 0 || ntree_limit > model_.trees.size()) {
      ntree_limit = static_cast<unsigned>(model_.trees.size());
    }
    uint32_t tree_begin, tree_end;
    std::tie(tree_begin, tree_end) = detail::LayerToTree(model_, tparam_, layer_begin, layer_end);
    // loop over output groups
    for (uint32_t gid = 0; gid < model_.learner_model_param->num_output_group; ++gid) {
      (*out_preds)[gid] =
          PredValue(inst, gid, &thread_temp_[0], 0, ntree_limit) +
          PredValue(inst, gid, &thread_temp_[0], 0, tree_end) +
          model_.learner_model_param->base_score;
    }
  }
@ -653,22 +677,25 @@ class Dart : public GBTree {

  void PredictContribution(DMatrix* p_fmat,
                           HostDeviceVector<bst_float>* out_contribs,
                           unsigned ntree_limit, bool approximate, int,
                           unsigned layer_begin, unsigned layer_end, bool approximate, int,
                           unsigned) override {
    CHECK(configured_);
    uint32_t tree_begin, tree_end;
    std::tie(tree_begin, tree_end) = detail::LayerToTree(model_, tparam_, layer_begin, layer_end);
    cpu_predictor_->PredictContribution(p_fmat, out_contribs, model_,
                                        ntree_limit, &weight_drop_, approximate);
                                        tree_end, &weight_drop_, approximate);
  }

  void PredictInteractionContributions(DMatrix* p_fmat,
                                       HostDeviceVector<bst_float>* out_contribs,
                                       unsigned ntree_limit, bool approximate) override {
  void PredictInteractionContributions(
      DMatrix *p_fmat, HostDeviceVector<bst_float> *out_contribs,
      unsigned layer_begin, unsigned layer_end, bool approximate) override {
    CHECK(configured_);
    uint32_t tree_begin, tree_end;
    std::tie(tree_begin, tree_end) = detail::LayerToTree(model_, tparam_, layer_begin, layer_end);
    cpu_predictor_->PredictInteractionContributions(p_fmat, out_contribs, model_,
                                                    ntree_limit, &weight_drop_, approximate);
                                                    tree_end, &weight_drop_, approximate);
  }

 protected:
  inline void PredLoopSpecalize(
      DMatrix* p_fmat,

@ -164,7 +164,9 @@ inline std::pair<uint32_t, uint32_t> LayerToTree(gbm::GBTreeModel const &model,
  if (tree_end == 0) {
    tree_end = static_cast<uint32_t>(model.trees.size());
  }
  CHECK_LT(tree_begin, tree_end);
  if (model.trees.size() != 0) {
    CHECK_LE(tree_begin, tree_end);
  }
  return {tree_begin, tree_end};
}
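
Note: LayerToTree is where the user-facing notion of "iteration" meets the flat tree array: one layer contains `num_parallel_tree * num_output_group` trees. A hedged sketch of the mapping as a standalone function (the real helper also handles DART slicing, which is omitted here):

    def layer_to_tree(layer_begin: int, layer_end: int,
                      num_parallel_tree: int, num_groups: int, n_trees: int):
        layer_trees = num_parallel_tree * max(num_groups, 1)
        tree_begin = layer_begin * layer_trees
        tree_end = layer_end * layer_trees if layer_end != 0 else n_trees
        assert n_trees == 0 or tree_begin <= tree_end
        return tree_begin, tree_end

    # 3-class model, 1 tree per layer per class, 30 trees total:
    assert layer_to_tree(0, 0, 1, 3, 30) == (0, 30)   # all rounds
    assert layer_to_tree(2, 5, 1, 3, 30) == (6, 15)   # rounds [2, 5)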

@ -260,10 +262,8 @@ class GBTree : public GradientBooster {
    return model_.trees.size() / this->LayerTrees();
  }

  void PredictBatch(DMatrix* p_fmat,
                    PredictionCacheEntry* out_preds,
                    bool training,
                    unsigned ntree_limit) override;
  void PredictBatch(DMatrix *p_fmat, PredictionCacheEntry *out_preds,
                    bool training, unsigned layer_begin, unsigned layer_end) override;

  void InplacePredict(dmlc::any const &x, std::shared_ptr<DMatrix> p_m,
                      float missing, PredictionCacheEntry *out_preds,
@ -297,33 +297,49 @@ class GBTree : public GradientBooster {

  void PredictInstance(const SparsePage::Inst& inst,
                       std::vector<bst_float>* out_preds,
                       unsigned ntree_limit) override {
                       uint32_t layer_begin, uint32_t layer_end) override {
    CHECK(configured_);
    uint32_t tree_begin, tree_end;
    std::tie(tree_begin, tree_end) = detail::LayerToTree(model_, tparam_, layer_begin, layer_end);
    cpu_predictor_->PredictInstance(inst, out_preds, model_,
                                    ntree_limit);
                                    tree_end);
  }

  void PredictLeaf(DMatrix* p_fmat,
                   HostDeviceVector<bst_float>* out_preds,
                   unsigned ntree_limit) override {
    this->GetPredictor()->PredictLeaf(p_fmat, out_preds, model_, ntree_limit);
                   uint32_t layer_begin, uint32_t layer_end) override {
    uint32_t tree_begin, tree_end;
    std::tie(tree_begin, tree_end) = detail::LayerToTree(model_, tparam_, layer_begin, layer_end);
    CHECK_EQ(tree_begin, 0) << "Predict leaf supports only iteration end: (0, "
                               "n_iteration), use model slicing instead.";
    this->GetPredictor()->PredictLeaf(p_fmat, out_preds, model_, tree_end);
  }

  void PredictContribution(DMatrix* p_fmat,
                           HostDeviceVector<bst_float>* out_contribs,
                           unsigned ntree_limit, bool approximate,
                           uint32_t layer_begin, uint32_t layer_end, bool approximate,
                           int, unsigned) override {
    CHECK(configured_);
    uint32_t tree_begin, tree_end;
    std::tie(tree_begin, tree_end) = detail::LayerToTree(model_, tparam_, layer_begin, layer_end);
    CHECK_EQ(tree_begin, 0)
        << "Predict contribution supports only iteration end: (0, "
           "n_iteration), use model slicing instead.";
    this->GetPredictor()->PredictContribution(
        p_fmat, out_contribs, model_, ntree_limit, nullptr, approximate);
        p_fmat, out_contribs, model_, tree_end, nullptr, approximate);
  }

  void PredictInteractionContributions(DMatrix* p_fmat,
                                       HostDeviceVector<bst_float>* out_contribs,
                                       unsigned ntree_limit, bool approximate) override {
  void PredictInteractionContributions(
      DMatrix *p_fmat, HostDeviceVector<bst_float> *out_contribs,
      uint32_t layer_begin, uint32_t layer_end, bool approximate) override {
    CHECK(configured_);
    this->GetPredictor()->PredictInteractionContributions(p_fmat, out_contribs, model_,
                                                          ntree_limit, nullptr, approximate);
    uint32_t tree_begin, tree_end;
    std::tie(tree_begin, tree_end) = detail::LayerToTree(model_, tparam_, layer_begin, layer_end);
    CHECK_EQ(tree_begin, 0)
        << "Predict interaction contribution supports only iteration end: (0, "
           "n_iteration), use model slicing instead.";
    this->GetPredictor()->PredictInteractionContributions(
        p_fmat, out_contribs, model_, tree_end, nullptr, approximate);
  }

  std::vector<std::string> DumpModel(const FeatureMap& fmap,

@ -22,6 +22,7 @@

#include "dmlc/any.h"
#include "xgboost/base.h"
#include "xgboost/c_api.h"
#include "xgboost/data.h"
#include "xgboost/model.h"
#include "xgboost/predictor.h"
@ -996,7 +997,7 @@ class LearnerImpl : public LearnerIO {
    auto& predt = local_cache->Cache(train, generic_parameters_.gpu_id);

    monitor_.Start("PredictRaw");
    this->PredictRaw(train.get(), &predt, true);
    this->PredictRaw(train.get(), &predt, true, 0, 0);
    TrainingObserver::Instance().Observe(predt.predictions, "Predictions");
    monitor_.Stop("PredictRaw");

@ -1057,7 +1058,7 @@ class LearnerImpl : public LearnerIO {
      std::shared_ptr<DMatrix> m = data_sets[i];
      auto &predt = local_cache->Cache(m, generic_parameters_.gpu_id);
      this->ValidateDMatrix(m.get(), false);
      this->PredictRaw(m.get(), &predt, false);
      this->PredictRaw(m.get(), &predt, false, 0, 0);

      auto &out = output_predictions_.Cache(m, generic_parameters_.gpu_id).predictions;
      out.Resize(predt.predictions.Size());
@ -1075,8 +1076,8 @@ class LearnerImpl : public LearnerIO {
  }

  void Predict(std::shared_ptr<DMatrix> data, bool output_margin,
               HostDeviceVector<bst_float>* out_preds, unsigned ntree_limit,
               bool training,
               HostDeviceVector<bst_float> *out_preds, unsigned layer_begin,
               unsigned layer_end, bool training,
               bool pred_leaf, bool pred_contribs, bool approx_contribs,
               bool pred_interactions) override {
    int multiple_predictions = static_cast<int>(pred_leaf) +
@ -1085,16 +1086,16 @@ class LearnerImpl : public LearnerIO {
    this->Configure();
    CHECK_LE(multiple_predictions, 1) << "Perform one kind of prediction at a time.";
    if (pred_contribs) {
      gbm_->PredictContribution(data.get(), out_preds, ntree_limit, approx_contribs);
      gbm_->PredictContribution(data.get(), out_preds, layer_begin, layer_end, approx_contribs);
    } else if (pred_interactions) {
      gbm_->PredictInteractionContributions(data.get(), out_preds, ntree_limit,
      gbm_->PredictInteractionContributions(data.get(), out_preds, layer_begin, layer_end,
                                            approx_contribs);
    } else if (pred_leaf) {
      gbm_->PredictLeaf(data.get(), out_preds, ntree_limit);
      gbm_->PredictLeaf(data.get(), out_preds, layer_begin, layer_end);
    } else {
      auto local_cache = this->GetPredictionCache();
      auto& prediction = local_cache->Cache(data, generic_parameters_.gpu_id);
      this->PredictRaw(data.get(), &prediction, training, ntree_limit);
      this->PredictRaw(data.get(), &prediction, training, layer_begin, layer_end);
      // Copy the prediction cache to output prediction. out_preds comes from C API
      out_preds->SetDevice(generic_parameters_.gpu_id);
      out_preds->Resize(prediction.predictions.Size());
@ -1151,12 +1152,11 @@ class LearnerImpl : public LearnerIO {
   * predictor, when it equals 0, this means we are using all the trees
   * \param training allow dropout when the DART booster is being used
   */
  void PredictRaw(DMatrix* data, PredictionCacheEntry* out_preds,
                  bool training,
                  unsigned ntree_limit = 0) const {
  void PredictRaw(DMatrix *data, PredictionCacheEntry *out_preds, bool training,
                  unsigned layer_begin, unsigned layer_end) const {
    CHECK(gbm_ != nullptr) << "Predict must happen after Load or configuration";
    this->ValidateDMatrix(data, false);
    gbm_->PredictBatch(data, out_preds, training, ntree_limit);
    gbm_->PredictBatch(data, out_preds, training, layer_begin, layer_end);
  }

  void ValidateDMatrix(DMatrix* p_fmat, bool is_training) const {

@ -234,56 +234,28 @@ class CPUPredictor : public Predictor {
 public:
  explicit CPUPredictor(GenericParameter const* generic_param) :
      Predictor::Predictor{generic_param} {}
  // ntree_limit is a very problematic parameter, as it's ambiguous in the context of
  // multi-output and forest. Same problem exists for tree_begin
  void PredictBatch(DMatrix* dmat, PredictionCacheEntry* predts,
                    const gbm::GBTreeModel& model, int tree_begin,
                    uint32_t const ntree_limit = 0) const override {
    // tree_begin is not used, right now we just enforce it to be 0.
    CHECK_EQ(tree_begin, 0);
  void PredictBatch(DMatrix *dmat, PredictionCacheEntry *predts,
                    const gbm::GBTreeModel &model, uint32_t tree_begin,
                    uint32_t tree_end = 0) const override {
    auto* out_preds = &predts->predictions;
    CHECK_GE(predts->version, tree_begin);
    if (out_preds->Size() == 0 && dmat->Info().num_row_ != 0) {
      CHECK_EQ(predts->version, 0);
    }
    // This is actually already handled in gbm, but a large number of tests rely on
    // the behaviour.
    if (tree_end == 0) {
      tree_end = model.trees.size();
    }
    if (predts->version == 0) {
      // out_preds->Size() can be non-zero as it's initialized here before any tree is
      // built at the 0^th iteration.
      this->InitOutPredictions(dmat->Info(), out_preds, model);
    }

    uint32_t const output_groups = model.learner_model_param->num_output_group;
    CHECK_NE(output_groups, 0);
    // Right now we just assume ntree_limit provided by users means number of tree layers
    // in the context of multi-output model
    uint32_t real_ntree_limit = ntree_limit * output_groups;
    if (real_ntree_limit == 0 || real_ntree_limit > model.trees.size()) {
      real_ntree_limit = static_cast<uint32_t>(model.trees.size());
    if (tree_end - tree_begin == 0) {
      return;
    }

    uint32_t const end_version = (tree_begin + real_ntree_limit) / output_groups;
    // When users have provided ntree_limit, end_version can be lesser, cache is violated
    if (predts->version > end_version) {
      CHECK_NE(ntree_limit, 0);
      this->InitOutPredictions(dmat->Info(), out_preds, model);
      predts->version = 0;
    }
    uint32_t const beg_version = predts->version;
    CHECK_LE(beg_version, end_version);

    if (beg_version < end_version) {
      this->PredictDMatrix(dmat, &out_preds->HostVector(), model,
                           beg_version * output_groups,
                           end_version * output_groups);
    }

    // delta means {size of forest} * {number of newly accumulated layers}
    uint32_t delta = end_version - beg_version;
    CHECK_LE(delta, model.trees.size());
    predts->Update(delta);

    CHECK(out_preds->Size() == output_groups * dmat->Info().num_row_ ||
          out_preds->Size() == dmat->Info().num_row_);
    this->PredictDMatrix(dmat, &out_preds->HostVector(), model, tree_begin,
                         tree_end);
  }

  template <typename Adapter>
@ -362,7 +334,6 @@ class CPUPredictor : public Predictor {
    InitThreadTemp(nthread, model.learner_model_param->num_feature, &feat_vecs);
    const MetaInfo& info = p_fmat->Info();
    // number of valid trees
    ntree_limit *= model.learner_model_param->num_output_group;
    if (ntree_limit == 0 || ntree_limit > model.trees.size()) {
      ntree_limit = static_cast<unsigned>(model.trees.size());
    }
@ -398,7 +369,6 @@ class CPUPredictor : public Predictor {
    InitThreadTemp(nthread, model.learner_model_param->num_feature, &feat_vecs);
    const MetaInfo& info = p_fmat->Info();
    // number of valid trees
    ntree_limit *= model.learner_model_param->num_output_group;
    if (ntree_limit == 0 || ntree_limit > model.trees.size()) {
      ntree_limit = static_cast<unsigned>(model.trees.size());
    }

@ -536,6 +536,7 @@ class GPUPredictor : public xgboost::Predictor {
    const uint32_t BLOCK_THREADS = 256;
    size_t num_rows = batch.n_rows;
    auto GRID_SIZE = static_cast<uint32_t>(common::DivRoundUp(num_rows, BLOCK_THREADS));
    DeviceModel d_model;

    bool use_shared = false;
    size_t entry_start = 0;
@ -593,54 +594,27 @@ class GPUPredictor : public xgboost::Predictor {
  }

  void PredictBatch(DMatrix* dmat, PredictionCacheEntry* predts,
                    const gbm::GBTreeModel& model, int tree_begin,
                    unsigned ntree_limit = 0) const override {
    // This function duplicates the CPU predictor's PredictBatch; see the comments there.
    // FIXME(trivialfis): Remove the duplication.
                    const gbm::GBTreeModel& model, uint32_t tree_begin,
                    uint32_t tree_end = 0) const override {
    int device = generic_param_->gpu_id;
    CHECK_GE(device, 0) << "Set `gpu_id' to positive value for processing GPU data.";
    ConfigureDevice(device);

    CHECK_EQ(tree_begin, 0);
    auto* out_preds = &predts->predictions;
    CHECK_GE(predts->version, tree_begin);

    if (out_preds->Size() == 0 && dmat->Info().num_row_ != 0) {
      CHECK_EQ(predts->version, 0);
    }
    if (tree_end == 0) {
      tree_end = model.trees.size();
    }
    if (predts->version == 0) {
      // out_preds->Size() can be non-zero as it's initialized here before any tree is
      // built at the 0^th iteration.
      this->InitOutPredictions(dmat->Info(), out_preds, model);
    }

    uint32_t const output_groups = model.learner_model_param->num_output_group;
    CHECK_NE(output_groups, 0);

    uint32_t real_ntree_limit = ntree_limit * output_groups;
    if (real_ntree_limit == 0 || real_ntree_limit > model.trees.size()) {
      real_ntree_limit = static_cast<uint32_t>(model.trees.size());
    if (tree_end - tree_begin == 0) {
      return;
    }

    uint32_t const end_version = (tree_begin + real_ntree_limit) / output_groups;

    if (predts->version > end_version) {
      CHECK_NE(ntree_limit, 0);
      this->InitOutPredictions(dmat->Info(), out_preds, model);
      predts->version = 0;
    }
    uint32_t const beg_version = predts->version;
    CHECK_LE(beg_version, end_version);

    if (beg_version < end_version) {
      this->DevicePredictInternal(dmat, out_preds, model,
                                  beg_version * output_groups,
                                  end_version * output_groups);
    }

    uint32_t delta = end_version - beg_version;
    CHECK_LE(delta, model.trees.size());
    predts->Update(delta);

    CHECK(out_preds->Size() == output_groups * dmat->Info().num_row_ ||
          out_preds->Size() == dmat->Info().num_row_);
    this->DevicePredictInternal(dmat, out_preds, model, tree_begin, tree_end);
  }

  template <typename Adapter, typename Loader>
@ -648,15 +622,12 @@ class GPUPredictor : public xgboost::Predictor {
                 const gbm::GBTreeModel &model, float,
                 PredictionCacheEntry *out_preds,
                 uint32_t tree_begin, uint32_t tree_end) const {
    auto max_shared_memory_bytes = dh::MaxSharedMemory(this->generic_param_->gpu_id);
    uint32_t const output_groups = model.learner_model_param->num_output_group;
    DeviceModel d_model;
    d_model.Init(model, tree_begin, tree_end, this->generic_param_->gpu_id);

    auto m = dmlc::get<std::shared_ptr<Adapter>>(x);
    CHECK_EQ(m->NumColumns(), model.learner_model_param->num_feature)
        << "Number of columns in data must equal to trained model.";
    CHECK_EQ(this->generic_param_->gpu_id, m->DeviceIdx())
    CHECK_EQ(dh::CurrentDevice(), m->DeviceIdx())
        << "XGBoost is running on device: " << this->generic_param_->gpu_id << ", "
        << "but data is on: " << m->DeviceIdx();
    if (p_m) {
@ -667,12 +638,17 @@ class GPUPredictor : public xgboost::Predictor {
      info.num_row_ = m->NumRows();
      this->InitOutPredictions(info, &(out_preds->predictions), model);
    }
    out_preds->predictions.SetDevice(m->DeviceIdx());

    const uint32_t BLOCK_THREADS = 128;
    auto GRID_SIZE = static_cast<uint32_t>(common::DivRoundUp(m->NumRows(), BLOCK_THREADS));

    auto max_shared_memory_bytes = dh::MaxSharedMemory(m->DeviceIdx());
    size_t shared_memory_bytes =
        SharedMemoryBytes<BLOCK_THREADS>(m->NumColumns(), max_shared_memory_bytes);
    DeviceModel d_model;
    d_model.Init(model, tree_begin, tree_end, m->DeviceIdx());

    bool use_shared = shared_memory_bytes != 0;
    size_t entry_start = 0;

@ -707,20 +683,17 @@ class GPUPredictor : public xgboost::Predictor {

  void PredictContribution(DMatrix* p_fmat,
                           HostDeviceVector<bst_float>* out_contribs,
                           const gbm::GBTreeModel& model, unsigned ntree_limit,
                           const gbm::GBTreeModel& model, unsigned tree_end,
                           std::vector<bst_float>*,
                           bool approximate, int,
                           unsigned) const override {
    if (approximate) {
      LOG(FATAL) << "Approximated contribution is not implemented in GPU Predictor.";
    }

    dh::safe_cuda(cudaSetDevice(generic_param_->gpu_id));
    out_contribs->SetDevice(generic_param_->gpu_id);
    uint32_t real_ntree_limit =
        ntree_limit * model.learner_model_param->num_output_group;
    if (real_ntree_limit == 0 || real_ntree_limit > model.trees.size()) {
      real_ntree_limit = static_cast<uint32_t>(model.trees.size());
    if (tree_end == 0 || tree_end > model.trees.size()) {
      tree_end = static_cast<uint32_t>(model.trees.size());
    }

    const int ngroup = model.learner_model_param->num_output_group;
@ -734,8 +707,7 @@ class GPUPredictor : public xgboost::Predictor {
    auto phis = out_contribs->DeviceSpan();

    dh::device_vector<gpu_treeshap::PathElement> device_paths;
    ExtractPaths(&device_paths, model, real_ntree_limit,
                 generic_param_->gpu_id);
    ExtractPaths(&device_paths, model, tree_end, generic_param_->gpu_id);
    for (auto& batch : p_fmat->GetBatches<SparsePage>()) {
      batch.data.SetDevice(generic_param_->gpu_id);
      batch.offset.SetDevice(generic_param_->gpu_id);
@ -761,20 +733,17 @@ class GPUPredictor : public xgboost::Predictor {
  void PredictInteractionContributions(DMatrix* p_fmat,
                                       HostDeviceVector<bst_float>* out_contribs,
                                       const gbm::GBTreeModel& model,
                                       unsigned ntree_limit,
                                       unsigned tree_end,
                                       std::vector<bst_float>*,
                                       bool approximate) const override {
    if (approximate) {
      LOG(FATAL) << "[Internal error]: " << __func__
                 << " approximate is not implemented in GPU Predictor.";
    }

    dh::safe_cuda(cudaSetDevice(generic_param_->gpu_id));
    out_contribs->SetDevice(generic_param_->gpu_id);
    uint32_t real_ntree_limit =
        ntree_limit * model.learner_model_param->num_output_group;
    if (real_ntree_limit == 0 || real_ntree_limit > model.trees.size()) {
      real_ntree_limit = static_cast<uint32_t>(model.trees.size());
    if (tree_end == 0 || tree_end > model.trees.size()) {
      tree_end = static_cast<uint32_t>(model.trees.size());
    }

    const int ngroup = model.learner_model_param->num_output_group;
@ -789,8 +758,7 @@ class GPUPredictor : public xgboost::Predictor {
    auto phis = out_contribs->DeviceSpan();

    dh::device_vector<gpu_treeshap::PathElement> device_paths;
    ExtractPaths(&device_paths, model, real_ntree_limit,
                 generic_param_->gpu_id);
    ExtractPaths(&device_paths, model, tree_end, generic_param_->gpu_id);
    for (auto& batch : p_fmat->GetBatches<SparsePage>()) {
      batch.data.SetDevice(generic_param_->gpu_id);
      batch.offset.SetDevice(generic_param_->gpu_id);
@ -841,29 +809,28 @@ class GPUPredictor : public xgboost::Predictor {
               << " is not implemented in GPU Predictor.";
  }

  void PredictLeaf(DMatrix* p_fmat, HostDeviceVector<bst_float>* predictions,
                   const gbm::GBTreeModel& model,
                   unsigned ntree_limit) const override {
  void PredictLeaf(DMatrix *p_fmat, HostDeviceVector<bst_float> *predictions,
                   const gbm::GBTreeModel &model,
                   unsigned tree_end) const override {
    dh::safe_cuda(cudaSetDevice(generic_param_->gpu_id));
    auto max_shared_memory_bytes = ConfigureDevice(generic_param_->gpu_id);

    const MetaInfo& info = p_fmat->Info();
    constexpr uint32_t kBlockThreads = 128;
    size_t shared_memory_bytes =
        SharedMemoryBytes<kBlockThreads>(info.num_col_, max_shared_memory_bytes);
    size_t shared_memory_bytes = SharedMemoryBytes<kBlockThreads>(
        info.num_col_, max_shared_memory_bytes);
    bool use_shared = shared_memory_bytes != 0;
    bst_feature_t num_features = info.num_col_;
    bst_row_t num_rows = info.num_row_;
    size_t entry_start = 0;

    uint32_t real_ntree_limit = ntree_limit * model.learner_model_param->num_output_group;
    if (real_ntree_limit == 0 || real_ntree_limit > model.trees.size()) {
      real_ntree_limit = static_cast<uint32_t>(model.trees.size());
    if (tree_end == 0 || tree_end > model.trees.size()) {
      tree_end = static_cast<uint32_t>(model.trees.size());
    }
    predictions->SetDevice(generic_param_->gpu_id);
    predictions->Resize(num_rows * real_ntree_limit);
    predictions->Resize(num_rows * tree_end);
    DeviceModel d_model;
    d_model.Init(model, 0, real_ntree_limit, this->generic_param_->gpu_id);
    d_model.Init(model, 0, tree_end, this->generic_param_->gpu_id);

    if (p_fmat->PageExists<SparsePage>()) {
      for (auto const& batch : p_fmat->GetBatches<SparsePage>()) {

@ -34,6 +34,7 @@ dependencies:
- llvmlite
- pip:
  - shap
  - ipython  # required by shap at import time.
  - guzzle_sphinx_theme
  - datatable
  - modin[all]

@ -51,6 +51,53 @@ TEST(GBTree, SelectTreeMethod) {
#endif  // XGBOOST_USE_CUDA
}

TEST(GBTree, PredictionCache) {
  size_t constexpr kRows = 100, kCols = 10;
  GenericParameter generic_param;
  generic_param.UpdateAllowUnknown(Args{});
  LearnerModelParam mparam;
  mparam.base_score = 0.5;
  mparam.num_feature = kCols;
  mparam.num_output_group = 1;

  std::unique_ptr<GradientBooster> p_gbm {
      GradientBooster::Create("gbtree", &generic_param, &mparam)};
  auto& gbtree = dynamic_cast<gbm::GBTree&>(*p_gbm);

  gbtree.Configure({{"tree_method", "hist"}});
  auto p_m = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix();
  auto gpair = GenerateRandomGradients(kRows);
  PredictionCacheEntry out_predictions;
  gbtree.DoBoost(p_m.get(), &gpair, &out_predictions);

  gbtree.PredictBatch(p_m.get(), &out_predictions, false, 0, 0);
  ASSERT_EQ(1, out_predictions.version);
  std::vector<float> first_iter = out_predictions.predictions.HostVector();
  // Add 1 more boosted round
  gbtree.DoBoost(p_m.get(), &gpair, &out_predictions);
  gbtree.PredictBatch(p_m.get(), &out_predictions, false, 0, 0);
  ASSERT_EQ(2, out_predictions.version);
  // Update the cache for all rounds
  out_predictions.version = 0;
  gbtree.PredictBatch(p_m.get(), &out_predictions, false, 0, 0);
  ASSERT_EQ(2, out_predictions.version);

  gbtree.DoBoost(p_m.get(), &gpair, &out_predictions);
  // drop the cache.
  gbtree.PredictBatch(p_m.get(), &out_predictions, false, 1, 2);
  ASSERT_EQ(0, out_predictions.version);
  // half open set [1, 3)
  gbtree.PredictBatch(p_m.get(), &out_predictions, false, 1, 3);
  ASSERT_EQ(0, out_predictions.version);
  // iteration end
  gbtree.PredictBatch(p_m.get(), &out_predictions, false, 0, 2);
  ASSERT_EQ(2, out_predictions.version);
  // restart the cache when end iteration is smaller than cache version
  gbtree.PredictBatch(p_m.get(), &out_predictions, false, 0, 1);
  ASSERT_EQ(1, out_predictions.version);
  ASSERT_EQ(out_predictions.predictions.HostVector(), first_iter);
}
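
Note: the same half-open iteration semantics tested above are reachable from Python through the public `iteration_range` argument. A quick sketch with synthetic data; the hyperparameters are arbitrary:

    import numpy as np
    import xgboost as xgb

    X, y = np.random.rand(64, 4), np.random.rand(64)
    dtrain = xgb.DMatrix(X, label=y)
    bst = xgb.train({"tree_method": "hist"}, dtrain, num_boost_round=4)

    head = bst.predict(dtrain, iteration_range=(0, 2))  # first two rounds only
    full = bst.predict(dtrain, iteration_range=(0, 0))  # 0 means all rounds
    assert not np.allclose(head, full)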

TEST(GBTree, WrongUpdater) {
  size_t constexpr kRows = 17;
  size_t constexpr kCols = 15;

@ -32,7 +32,7 @@ TEST(CpuPredictor, Basic) {
|
||||
// Test predict batch
|
||||
PredictionCacheEntry out_predictions;
|
||||
cpu_predictor->PredictBatch(dmat.get(), &out_predictions, model, 0);
|
||||
ASSERT_EQ(model.trees.size(), out_predictions.version);
|
||||
|
||||
std::vector<float>& out_predictions_h = out_predictions.predictions.HostVector();
|
||||
for (size_t i = 0; i < out_predictions.predictions.Size(); i++) {
|
||||
ASSERT_EQ(out_predictions_h[i], 1.5);
|
||||
@ -215,7 +215,7 @@ TEST(CpuPredictor, UpdatePredictionCache) {
|
||||
|
||||
PredictionCacheEntry out_predictions;
|
||||
// perform fair prediction on the same input data, should be equal to cached result
|
||||
gbm->PredictBatch(dmat.get(), &out_predictions, false, 0);
|
||||
gbm->PredictBatch(dmat.get(), &out_predictions, false, 0, 0);
|
||||
|
||||
std::vector<float> &out_predictions_h = out_predictions.predictions.HostVector();
|
||||
std::vector<float> &predtion_cache_from_train = predtion_cache.predictions.HostVector();
|
||||
|
||||
@ -45,7 +45,6 @@ TEST(GPUPredictor, Basic) {
|
||||
PredictionCacheEntry cpu_out_predictions;
|
||||
|
||||
gpu_predictor->PredictBatch(dmat.get(), &gpu_out_predictions, model, 0);
|
||||
ASSERT_EQ(model.trees.size(), gpu_out_predictions.version);
|
||||
cpu_predictor->PredictBatch(dmat.get(), &cpu_out_predictions, model, 0);
|
||||
|
||||
std::vector<float>& gpu_out_predictions_h = gpu_out_predictions.predictions.HostVector();
|
||||
|
||||
@ -64,10 +64,10 @@ void TestTrainingPrediction(size_t rows, size_t bins,
|
||||
}
|
||||
|
||||
HostDeviceVector<float> from_full;
|
||||
learner->Predict(p_full, false, &from_full);
|
||||
learner->Predict(p_full, false, &from_full, 0, 0);
|
||||
|
||||
HostDeviceVector<float> from_hist;
|
||||
learner->Predict(p_hist, false, &from_hist);
|
||||
learner->Predict(p_hist, false, &from_hist, 0, 0);
|
||||
|
||||
for (size_t i = 0; i < rows; ++i) {
|
||||
EXPECT_NEAR(from_hist.ConstHostVector()[i],
|
||||
@@ -157,20 +157,20 @@ void TestPredictionWithLesserFeatures(std::string predictor_name) {
  learner->SaveConfig(&config);
  ASSERT_EQ(get<String>(config["learner"]["gradient_booster"]["gbtree_train_param"]["predictor"]), predictor_name);

  learner->Predict(m_test, false, &prediction);
  learner->Predict(m_test, false, &prediction, 0, 0);
  ASSERT_EQ(prediction.Size(), kRows);

  auto m_invalid = RandomDataGenerator(kRows, kTrainCols + 1, 0.5).GenerateDMatrix(false);
  ASSERT_THROW({learner->Predict(m_invalid, false, &prediction);}, dmlc::Error);
  ASSERT_THROW({learner->Predict(m_invalid, false, &prediction, 0, 0);}, dmlc::Error);

#if defined(XGBOOST_USE_CUDA)
  HostDeviceVector<float> from_cpu;
  learner->SetParam("predictor", "cpu_predictor");
  learner->Predict(m_test, false, &from_cpu);
  learner->Predict(m_test, false, &from_cpu, 0, 0);

  HostDeviceVector<float> from_cuda;
  learner->SetParam("predictor", "gpu_predictor");
  learner->Predict(m_test, false, &from_cuda);
  learner->Predict(m_test, false, &from_cuda, 0, 0);

  auto const& h_cpu = from_cpu.ConstHostVector();
  auto const& h_gpu = from_cuda.ConstHostVector();
@@ -221,9 +221,10 @@ TEST(Learner, MultiThreadedPredict) {
      auto &entry = learner->GetThreadLocal().prediction_entry;
      HostDeviceVector<float> predictions;
      for (size_t iter = 0; iter < kIters; ++iter) {
        learner->Predict(p_data, false, &entry.predictions);
        learner->Predict(p_data, false, &predictions, 0, true);         // leaf
        learner->Predict(p_data, false, &predictions, 0, false, true);  // contribs
        learner->Predict(p_data, false, &entry.predictions, 0, 0);

        learner->Predict(p_data, false, &predictions, 0, 0, false, true);         // leaf
        learner->Predict(p_data, false, &predictions, 0, 0, false, false, true);  // contribs
      }
    });
  }
@@ -112,17 +112,24 @@ def _test_cupy_metainfo(DMatrixT):
@pytest.mark.skipif(**tm.no_sklearn())
def test_cupy_training_with_sklearn():
    import cupy as cp

    np.random.seed(1)
    cp.random.seed(1)
    X = cp.random.randn(50, 10, dtype='float32')
    y = (cp.random.randn(50, dtype='float32') > 0).astype('int8')
    X = cp.random.randn(50, 10, dtype="float32")
    y = (cp.random.randn(50, dtype="float32") > 0).astype("int8")
    weights = np.random.random(50) + 1
    cupy_weights = cp.array(weights)
    base_margin = np.random.random(50)
    cupy_base_margin = cp.array(base_margin)

    clf = xgb.XGBClassifier(gpu_id=0, tree_method='gpu_hist', use_label_encoder=False)
    clf.fit(X, y, sample_weight=cupy_weights, base_margin=cupy_base_margin, eval_set=[(X, y)])
    clf = xgb.XGBClassifier(gpu_id=0, tree_method="gpu_hist", use_label_encoder=False)
    clf.fit(
        X,
        y,
        sample_weight=cupy_weights,
        base_margin=cupy_base_margin,
        eval_set=[(X, y)],
    )
    pred = clf.predict(X)
    assert np.array_equal(np.unique(pred), np.array([0, 1]))
@@ -16,13 +16,15 @@ if sys.platform.startswith("win"):
    pytest.skip("Skipping dask tests on Windows", allow_module_level=True)

sys.path.append("tests/python")
from test_with_dask import run_empty_dmatrix_reg  # noqa
from test_with_dask import run_empty_dmatrix_cls  # noqa
from test_with_dask import _get_client_workers  # noqa
from test_with_dask import generate_array  # noqa
from test_with_dask import kCols as random_cols  # noqa
from test_with_dask import suppress  # noqa
import testing as tm  # noqa
from test_with_dask import run_empty_dmatrix_reg  # noqa
from test_with_dask import run_boost_from_prediction  # noqa
from test_with_dask import run_dask_classifier  # noqa
from test_with_dask import run_empty_dmatrix_cls  # noqa
from test_with_dask import _get_client_workers  # noqa
from test_with_dask import generate_array  # noqa
from test_with_dask import kCols as random_cols  # noqa
from test_with_dask import suppress  # noqa
import testing as tm  # noqa


try:
@@ -132,9 +134,9 @@ def run_gpu_hist(
    num_rounds: int,
    dataset: tm.TestDataset,
    DMatrixT: Type,
    client: Client
    client: Client,
) -> None:
    params['tree_method'] = 'gpu_hist'
    params["tree_method"] = "gpu_hist"
    params = dataset.set_params(params)
    # It doesn't make sense to distribute a completely
    # empty dataset.
@@ -143,26 +145,40 @@ def run_gpu_hist(

    chunk = 128
    X = to_cp(dataset.X, DMatrixT)
    X = da.from_array(X,
                      chunks=(chunk, dataset.X.shape[1]))
    X = da.from_array(X, chunks=(chunk, dataset.X.shape[1]))
    y = to_cp(dataset.y, DMatrixT)
    y = da.from_array(y, chunks=(chunk, ))
    y = da.from_array(y, chunks=(chunk,))
    if dataset.w is not None:
        w = to_cp(dataset.w, DMatrixT)
        w = da.from_array(w, chunks=(chunk, ))
        w = da.from_array(w, chunks=(chunk,))
    else:
        w = None

    if DMatrixT is dxgb.DaskDeviceQuantileDMatrix:
        m = DMatrixT(client, data=X, label=y, weight=w,
                     max_bin=params.get('max_bin', 256))
        m = DMatrixT(
            client, data=X, label=y, weight=w, max_bin=params.get("max_bin", 256)
        )
    else:
        m = DMatrixT(client, data=X, label=y, weight=w)
    history = dxgb.train(client, params=params, dtrain=m,
                         num_boost_round=num_rounds,
                         evals=[(m, 'train')])['history']
    history = dxgb.train(
        client,
        params=params,
        dtrain=m,
        num_boost_round=num_rounds,
        evals=[(m, "train")],
    )["history"]
    note(history)
    assert tm.non_increasing(history['train'][dataset.metric])
    assert tm.non_increasing(history["train"][dataset.metric])


def test_boost_from_prediction(local_cuda_cluster: LocalCUDACluster) -> None:
    import cudf
    from sklearn.datasets import load_breast_cancer
    with Client(local_cuda_cluster) as client:
        X_, y_ = load_breast_cancer(return_X_y=True)
        X = dd.from_array(X_, chunksize=100).map_partitions(cudf.from_pandas)
        y = dd.from_array(y_, chunksize=100).map_partitions(cudf.from_pandas)
        run_boost_from_prediction(X, y, "gpu_hist", client)


class TestDistributedGPU:
@@ -246,6 +262,20 @@ class TestDistributedGPU:
            dump = booster.get_dump(dump_format='json')
            assert len(dump) - booster.best_iteration == early_stopping_rounds + 1

    @pytest.mark.skipif(**tm.no_cudf())
    @pytest.mark.skipif(**tm.no_dask())
    @pytest.mark.skipif(**tm.no_dask_cuda())
    @pytest.mark.parametrize("model", ["boosting"])
    def test_dask_classifier(self, model, local_cuda_cluster: LocalCUDACluster) -> None:
        import dask_cudf
        with Client(local_cuda_cluster) as client:
            X_, y_, w_ = generate_array(with_weights=True)
            y_ = (y_ * 10).astype(np.int32)
            X = dask_cudf.from_dask_dataframe(dd.from_dask_array(X_))
            y = dask_cudf.from_dask_dataframe(dd.from_dask_array(y_))
            w = dask_cudf.from_dask_dataframe(dd.from_dask_array(w_))
            run_dask_classifier(X, y, w, model, client)

    @pytest.mark.skipif(**tm.no_dask())
    @pytest.mark.skipif(**tm.no_dask_cuda())
    @pytest.mark.mgpu
@@ -434,7 +434,13 @@ class TestModels:
        booster[...:end] = booster

        sliced_0 = booster[1:3]
        np.testing.assert_allclose(
            booster.predict(dtrain, iteration_range=(1, 3)), sliced_0.predict(dtrain)
        )
        sliced_1 = booster[3:7]
        np.testing.assert_allclose(
            booster.predict(dtrain, iteration_range=(3, 7)), sliced_1.predict(dtrain)
        )

        predt_0 = sliced_0.predict(dtrain, output_margin=True)
        predt_1 = sliced_1.predict(dtrain, output_margin=True)
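
Booster slicing shares the half-open convention with `iteration_range`, and a slice acts as a standalone model. A hedged sketch of the equivalence checked above, assuming a freshly trained regression booster (names and sizes illustrative):

    import numpy as np
    import xgboost as xgb

    X = np.random.randn(256, 8)
    y = np.random.randn(256)
    dtrain = xgb.DMatrix(X, y)
    booster = xgb.train({"tree_method": "hist"}, dtrain, num_boost_round=8)

    sliced = booster[2:5]  # iterations 2, 3 and 4, half-open like iteration_range
    np.testing.assert_allclose(
        booster.predict(dtrain, iteration_range=(2, 5)), sliced.predict(dtrain)
    )
    margin = sliced.predict(dtrain, output_margin=True)  # raw margin of the slice alone
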
@@ -47,30 +47,27 @@ def run_predict_leaf(predictor):
    empty_leaf = booster.predict(empty, pred_leaf=True)
    assert empty_leaf.shape[0] == 0

    leaf = booster.predict(m, pred_leaf=True)
    leaf = booster.predict(m, pred_leaf=True, strict_shape=True)
    assert leaf.shape[0] == rows
    assert leaf.shape[1] == classes * num_parallel_tree * num_boost_round
    assert leaf.shape[1] == num_boost_round
    assert leaf.shape[2] == classes
    assert leaf.shape[3] == num_parallel_tree

    for i in range(rows):
        row = leaf[i, ...]
        for j in range(num_boost_round):
            start = classes * num_parallel_tree * j
            end = classes * num_parallel_tree * (j + 1)
            layer = row[start: end]
            for c in range(classes):
                tree_group = layer[c * num_parallel_tree: (c + 1) * num_parallel_tree]
            for k in range(classes):
                tree_group = leaf[i, j, k, :]
                assert tree_group.shape[0] == num_parallel_tree
                # no subsampling so tree in same forest should output same
                # leaf.
                # No sampling, all trees within forest are the same
                assert np.all(tree_group == tree_group[0])

    ntree_limit = 2
    sliced = booster.predict(
        m, pred_leaf=True, ntree_limit=num_parallel_tree * ntree_limit
        m, pred_leaf=True, ntree_limit=num_parallel_tree * ntree_limit, strict_shape=True
    )
    first = sliced[0, ...]

    assert first.shape[0] == classes * num_parallel_tree * ntree_limit
    assert np.prod(first.shape) == classes * num_parallel_tree * ntree_limit
    return leaf
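
With `strict_shape=True`, leaf prediction always returns the four-dimensional layout `(n_samples, n_iterations, n_classes, n_trees_in_forest)` documented for the C API, instead of the old flattened second axis. A sketch with a small multi-class forest (all values illustrative):

    import numpy as np
    import xgboost as xgb

    rows, classes, num_parallel_tree, num_boost_round = 100, 3, 4, 2
    X = np.random.randn(rows, 5)
    y = np.random.randint(0, classes, size=rows)
    booster = xgb.train(
        {
            "objective": "multi:softprob",
            "num_class": classes,
            "num_parallel_tree": num_parallel_tree,
        },
        xgb.DMatrix(X, y),
        num_boost_round=num_boost_round,
    )
    leaf = booster.predict(xgb.DMatrix(X), pred_leaf=True, strict_shape=True)
    assert leaf.shape == (rows, num_boost_round, classes, num_parallel_tree)
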
@@ -78,6 +75,23 @@ def test_predict_leaf():
    run_predict_leaf('cpu_predictor')


def test_predict_shape():
    from sklearn.datasets import load_boston
    X, y = load_boston(return_X_y=True)
    reg = xgb.XGBRegressor(n_estimators=1)
    reg.fit(X, y)
    predt = reg.get_booster().predict(xgb.DMatrix(X), strict_shape=True)
    assert len(predt.shape) == 2
    assert predt.shape[0] == X.shape[0]
    assert predt.shape[1] == 1

    contrib = reg.get_booster().predict(
        xgb.DMatrix(X), pred_contribs=True, strict_shape=True
    )
    assert len(contrib.shape) == 3
    assert contrib.shape[1] == 1


class TestInplacePredict:
    '''Tests for running inplace prediction'''
    @classmethod
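
For comparison with the leaf case, `strict_shape=True` makes a normal prediction two-dimensional, `(n_samples, n_groups)`, even for a single-output model, and a contribution prediction three-dimensional. The size of the trailing contribution axis (one value per feature plus a bias term) is an assumption based on the usual SHAP layout, not asserted by the test above:

    import xgboost as xgb
    from sklearn.datasets import load_boston

    X, y = load_boston(return_X_y=True)
    reg = xgb.XGBRegressor(n_estimators=1).fit(X, y)

    predt = reg.get_booster().predict(xgb.DMatrix(X), strict_shape=True)
    assert predt.shape == (X.shape[0], 1)  # (n_samples, n_groups)

    contrib = reg.get_booster().predict(
        xgb.DMatrix(X), pred_contribs=True, strict_shape=True
    )
    assert contrib.shape == (X.shape[0], 1, X.shape[1] + 1)  # features + bias (assumed)
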
@@ -92,8 +106,7 @@ class TestInplacePredict:

        dtrain = xgb.DMatrix(cls.X, cls.y)

        cls.booster = xgb.train({'tree_method': 'hist'},
                                dtrain, num_boost_round=10)
        cls.booster = xgb.train({'tree_method': 'hist'}, dtrain, num_boost_round=10)

        cls.test = xgb.DMatrix(cls.X[:10, ...])
@@ -159,12 +159,9 @@ def test_dask_predict_shape_infer(client: "Client") -> None:
    assert prediction.shape[1] == 3


@pytest.mark.parametrize("tree_method", ["hist", "approx"])
def test_boost_from_prediction(tree_method: str, client: "Client") -> None:
    from sklearn.datasets import load_breast_cancer
    X_, y_ = load_breast_cancer(return_X_y=True)

    X, y = dd.from_array(X_, chunksize=100), dd.from_array(y_, chunksize=100)
def run_boost_from_prediction(
    X: xgb.dask._DaskCollection, y: xgb.dask._DaskCollection, tree_method: str, client: "Client"
) -> None:
    model_0 = xgb.dask.DaskXGBClassifier(
        learning_rate=0.3, random_state=0, n_estimators=4,
        tree_method=tree_method)
@@ -202,6 +199,30 @@ def test_boost_from_prediction(tree_method: str, client: "Client") -> None:
        assert margined_res[i] < unmargined_res[i]


@pytest.mark.parametrize("tree_method", ["hist", "approx"])
def test_boost_from_prediction(tree_method: str, client: "Client") -> None:
    from sklearn.datasets import load_breast_cancer
    X_, y_ = load_breast_cancer(return_X_y=True)
    X, y = dd.from_array(X_, chunksize=100), dd.from_array(y_, chunksize=100)
    run_boost_from_prediction(X, y, tree_method, client)


def test_inplace_predict(client: "Client") -> None:
    from sklearn.datasets import load_boston
    X_, y_ = load_boston(return_X_y=True)
    X, y = dd.from_array(X_, chunksize=32), dd.from_array(y_, chunksize=32)
    reg = xgb.dask.DaskXGBRegressor(n_estimators=4).fit(X, y)
    booster = reg.get_booster()
    base_margin = y

    inplace = xgb.dask.inplace_predict(
        client, booster, X, base_margin=base_margin
    ).compute()
    Xy = xgb.dask.DaskDMatrix(client, X, base_margin=base_margin)
    copied = xgb.dask.predict(client, booster, Xy).compute()
    np.testing.assert_allclose(inplace, copied)


def test_dask_missing_value_reg(client: "Client") -> None:
    X_0 = np.ones((20 // 2, kCols))
    X_1 = np.zeros((20 // 2, kCols))
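
Per the commit message, the dask sklearn wrappers now route `predict` through `inplace_predict` by default, so no intermediate `DaskDMatrix` is materialized on the prediction path. A brief usage sketch under the same assumptions as the test above (data already loaded as dask collections `X` and `y`, a running `client`):

    import xgboost as xgb

    reg = xgb.dask.DaskXGBRegressor(n_estimators=4)
    reg.client = client
    reg.fit(X, y)
    predt = reg.predict(X)    # lazy dask collection, served by inplace_predict
    values = predt.compute()  # materialize the actual result
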
@@ -288,10 +309,13 @@ def test_dask_regressor(model: str, client: "Client") -> None:
    assert forest == 2


@pytest.mark.parametrize("model", ["boosting", "rf"])
def test_dask_classifier(model: str, client: "Client") -> None:
    X, y, w = generate_array(with_weights=True)
    y = (y * 10).astype(np.int32)
def run_dask_classifier(
    X: xgb.dask._DaskCollection,
    y: xgb.dask._DaskCollection,
    w: xgb.dask._DaskCollection,
    model: str,
    client: "Client",
) -> None:
    if model == "boosting":
        classifier = xgb.dask.DaskXGBClassifier(
            verbosity=1, n_estimators=2, eval_metric="merror"
@@ -306,14 +330,13 @@ def test_dask_classifier(model: str, client: "Client") -> None:

    classifier.client = client
    classifier.fit(X, y, sample_weight=w, eval_set=[(X, y)])
    prediction = classifier.predict(X)
    prediction = classifier.predict(X).compute()

    assert prediction.ndim == 1
    assert prediction.shape[0] == kRows

    history = classifier.evals_result()

    assert isinstance(prediction, da.Array)
    assert isinstance(history, dict)

    assert list(history.keys())[0] == "validation_0"
@@ -332,7 +355,7 @@ def test_dask_classifier(model: str, client: "Client") -> None:
    assert forest == 2

    # Test .predict_proba()
    probas = classifier.predict_proba(X)
    probas = classifier.predict_proba(X).compute()
    assert classifier.n_classes_ == 10
    assert probas.ndim == 2
    assert probas.shape[0] == kRows
@@ -341,18 +364,33 @@ def test_dask_classifier(model: str, client: "Client") -> None:
    cls_booster = classifier.get_booster()
    single_node_proba = cls_booster.inplace_predict(X.compute())

    np.testing.assert_allclose(single_node_proba, probas.compute())
    # test shared by CPU and GPU
    if isinstance(single_node_proba, np.ndarray):
        np.testing.assert_allclose(single_node_proba, probas)
    else:
        import cupy
        cupy.testing.assert_allclose(single_node_proba, probas)

    # Test with dataframe.
    X_d = dd.from_dask_array(X)
    y_d = dd.from_dask_array(y)
    classifier.fit(X_d, y_d)
    # Test with dataframe, not shared with GPU as cupy doesn't work well with da.unique.
    if isinstance(X, da.Array):
        X_d: dd.DataFrame = X.to_dask_dataframe()

    assert classifier.n_classes_ == 10
    prediction = classifier.predict(X_d).compute()
    assert classifier.n_classes_ == 10
    prediction_df = classifier.predict(X_d).compute()

    assert prediction.ndim == 1
    assert prediction.shape[0] == kRows
    assert prediction_df.ndim == 1
    assert prediction_df.shape[0] == kRows
    np.testing.assert_allclose(prediction_df, prediction)

    probas = classifier.predict_proba(X).compute()
    np.testing.assert_allclose(single_node_proba, probas)


@pytest.mark.parametrize("model", ["boosting", "rf"])
def test_dask_classifier(model: str, client: "Client") -> None:
    X, y, w = generate_array(with_weights=True)
    y = (y * 10).astype(np.int32)
    run_dask_classifier(X, y, w, model, client)


@pytest.mark.skipif(**tm.no_sklearn())
@@ -913,9 +951,9 @@ class TestWithDask:
        train = xgb.dask.DaskDMatrix(client, dX, dy)

        dX = dd.from_array(X)
        dX = client.persist(dX, workers={dX: workers[1]})
        dX = client.persist(dX, workers=workers[1])
        dy = dd.from_array(y)
        dy = client.persist(dy, workers={dy: workers[1]})
        dy = client.persist(dy, workers=workers[1])
        valid = xgb.dask.DaskDMatrix(client, dX, dy)

        merged = xgb.dask._get_workers_from_data(train, evals=[(valid, 'Valid')])
@@ -1060,6 +1098,16 @@ class TestWithDask:
        assert_shape(shap.shape)
        assert np.allclose(np.sum(shap, axis=len(shap.shape) - 1), margin, 1e-5, 1e-5)

        X = dd.from_dask_array(X).repartition(npartitions=32)
        y = dd.from_dask_array(y).repartition(npartitions=32)
        shap_df = xgb.dask.predict(
            client, booster, X, pred_contribs=True, validate_features=False
        ).compute()
        assert_shape(shap_df.shape)
        assert np.allclose(
            np.sum(shap_df, axis=len(shap_df.shape) - 1), margin, 1e-5, 1e-5
        )

    def run_shap_cls_sklearn(self, X: Any, y: Any, client: "Client") -> None:
        X, y = da.from_array(X, chunks=(32, -1)), da.from_array(y, chunks=32)
        cls = xgb.dask.DaskXGBClassifier(n_estimators=4)
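
The dataframe branch above relies on the usual SHAP identity: per-row contributions, including the bias term, sum to the raw margin. A hedged restatement, assuming `booster`, `X`, and `client` as in the surrounding test:

    import numpy as np
    import xgboost as xgb

    shap = xgb.dask.predict(
        client, booster, X, pred_contribs=True, validate_features=False
    ).compute()
    margin = xgb.dask.predict(client, booster, X, output_margin=True).compute()
    np.testing.assert_allclose(np.sum(shap, axis=-1), margin, rtol=1e-5, atol=1e-5)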