[breaking] Add prediction fucntion for DMatrix and use inplace predict for dask. (#6668)

* Add a new API function for predicting on `DMatrix`.  This function aligns
with rest of the `XGBoosterPredictFrom*` functions on semantic of function
arguments.
* Purge `ntree_limit` from libxgboost, use iteration instead.
* [dask] Use `inplace_predict` by default for dask sklearn models.
* [dask] Run prediction shape inference on worker instead of client.

The breaking change is in the Python sklearn `apply` function, I made it to be
consistent with other prediction functions where `best_iteration` is used by
default.
This commit is contained in:
Jiaming Yuan 2021-02-08 18:26:32 +08:00 committed by GitHub
parent dbb5208a0a
commit 4656b09d5d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
29 changed files with 1134 additions and 604 deletions

View File

@ -704,8 +704,9 @@ XGB_DLL int XGBoosterEvalOneIter(BoosterHandle handle,
const char *evnames[],
bst_ulong len,
const char **out_result);
/*!
* \brief make prediction based on dmat
* \brief make prediction based on dmat (deprecated, use `XGBoosterPredictFromDMatrix` instead)
* \param handle handle
* \param dmat data matrix
* \param option_mask bit-mask of options taken in prediction, possible values
@ -734,6 +735,165 @@ XGB_DLL int XGBoosterPredict(BoosterHandle handle,
int training,
bst_ulong *out_len,
const float **out_result);
/*!
* \brief Make prediction from DMatrix, replacing `XGBoosterPredict`.
*
* \param handle Booster handle
* \param dmat DMatrix handle
* \param c_json_config String encoded predict configuration in JSON format.
*
* "type": [0, 5]
* 0: normal prediction
* 1: output margin
* 2: predict contribution
* 3: predict approxmated contribution
* 4: predict feature interaction
* 5: predict leaf
* "training": bool
* Whether the prediction function is used as part of a training loop. **Not used
* for inplace prediction**.
*
* Prediction can be run in 2 scenarios:
* 1. Given data matrix X, obtain prediction y_pred from the model.
* 2. Obtain the prediction for computing gradients. For example, DART booster performs dropout
* during training, and the prediction result will be different from the one obtained by normal
* inference step due to dropped trees.
* Set training=false for the first scenario. Set training=true for the second
* scenario. The second scenario applies when you are defining a custom objective
* function.
* "iteration_begin": int
* Beginning iteration of prediction.
* "iteration_end": int
* End iteration of prediction. Set to 0 this will become the size of tree model.
* "strict_shape": bool
* Whether should we reshape the output with stricter rules. If set to true,
* normal/margin/contrib/interaction predict will output consistent shape
* disregarding the use of multi-class model, and leaf prediction will output 4-dim
* array representing: (n_samples, n_iterations, n_classes, n_trees_in_forest)
*
* Run a normal prediction with strict output shape, 2 dim for softprob , 1 dim for others.
* \code
* {
* "type": 0,
* "training": False,
* "iteration_begin": 0,
* "iteration_end": 0,
* "strict_shape": true,
* }
* \endcode
*
* \param out_shape Shape of output prediction (copy before use).
* \param out_dim Dimension of output prediction.
* \param out_result Buffer storing prediction value (copy before use).
*
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGBoosterPredictFromDMatrix(BoosterHandle handle,
DMatrixHandle dmat,
char const* c_json_config,
bst_ulong const **out_shape,
bst_ulong *out_dim,
float const **out_result);
/*
* \brief Inplace prediction from CPU dense matrix.
*
* \param handle Booster handle.
* \param values JSON encoded __array_interface__ to values.
* \param c_json_config See `XGBoosterPredictFromDMatrix` for more info.
*
* Additional fields for inplace prediction are:
* "missing": float
*
* \param m An optional (NULL if not available) proxy DMatrix instance
* storing meta info.
*
* \param out_shape See `XGBoosterPredictFromDMatrix` for more info.
* \param out_dim See `XGBoosterPredictFromDMatrix` for more info.
* \param out_result See `XGBoosterPredictFromDMatrix` for more info.
*
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGBoosterPredictFromDense(BoosterHandle handle,
char const *values,
char const *c_json_config,
DMatrixHandle m,
bst_ulong const **out_shape,
bst_ulong *out_dim,
const float **out_result);
/*
* \brief Inplace prediction from CPU CSR matrix.
*
* \param handle Booster handle.
* \param indptr JSON encoded __array_interface__ to row pointer in CSR.
* \param indices JSON encoded __array_interface__ to column indices in CSR.
* \param values JSON encoded __array_interface__ to values in CSR..
* \param ncol Number of features in data.
* \param c_json_config See `XGBoosterPredictFromDMatrix` for more info.
* Additional fields for inplace prediction are:
* "missing": float
*
* \param m An optional (NULL if not available) proxy DMatrix instance
* storing meta info.
*
* \param out_shape See `XGBoosterPredictFromDMatrix` for more info.
* \param out_dim See `XGBoosterPredictFromDMatrix` for more info.
* \param out_result See `XGBoosterPredictFromDMatrix` for more info.
*
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGBoosterPredictFromCSR(BoosterHandle handle, char const *indptr,
char const *indices, char const *values,
bst_ulong ncol,
char const *c_json_config, DMatrixHandle m,
bst_ulong const **out_shape,
bst_ulong *out_dim,
const float **out_result);
/*
* \brief Inplace prediction from CUDA Dense matrix (cupy in Python).
*
* \param handle Booster handle
* \param values JSON encoded __cuda_array_interface__ to values.
* \param c_json_config See `XGBoosterPredictFromDMatrix` for more info.
* Additional fields for inplace prediction are:
* "missing": float
*
* \param m An optional (NULL if not available) proxy DMatrix instance
* storing meta info.
* \param out_shape See `XGBoosterPredictFromDMatrix` for more info.
* \param out_dim See `XGBoosterPredictFromDMatrix` for more info.
* \param out_result See `XGBoosterPredictFromDMatrix` for more info.
*
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGBoosterPredictFromCudaArray(
BoosterHandle handle, char const *values, char const *c_json_config,
DMatrixHandle m, bst_ulong const **out_shape, bst_ulong *out_dim,
const float **out_result);
/*
* \brief Inplace prediction from CUDA dense dataframe (cuDF in Python).
*
* \param handle Booster handle
* \param values List of __cuda_array_interface__ for all columns encoded in JSON list.
* \param c_json_config See `XGBoosterPredictFromDMatrix` for more info.
* Additional fields for inplace prediction are:
* "missing": float
*
* \param m An optional (NULL if not available) proxy DMatrix instance
* storing meta info.
* \param out_shape See `XGBoosterPredictFromDMatrix` for more info.
* \param out_dim See `XGBoosterPredictFromDMatrix` for more info.
* \param out_result See `XGBoosterPredictFromDMatrix` for more info.
*
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGBoosterPredictFromCudaColumnar(
BoosterHandle handle, char const *values, char const *c_json_config,
DMatrixHandle m, bst_ulong const **out_shape, bst_ulong *out_dim,
const float **out_result);
/*
* ========================== Begin Serialization APIs =========================

View File

@ -63,7 +63,7 @@ class GradientBooster : public Model, public Configurable {
/*!
* \brief Slice a model using boosting index. The slice m:n indicates taking all trees
* that were fit during the boosting rounds m, (m+1), (m+2), ..., (n-1).
* \param layer_begin Begining of boosted tree layer used for prediction.
* \param layer_begin Beginning of boosted tree layer used for prediction.
* \param layer_end End of booster layer. 0 means do not limit trees.
* \param out Output gradient booster
*/
@ -99,15 +99,14 @@ class GradientBooster : public Model, public Configurable {
* \param out_preds output vector to hold the predictions
* \param training Whether the prediction value is used for training. For dart booster
* drop out is performed during training.
* \param ntree_limit limit the number of trees used in prediction,
* when it equals 0, this means we do not limit
* number of trees, this parameter is only valid
* for gbtree, but not for gblinear
* \param layer_begin Beginning of boosted tree layer used for prediction.
* \param layer_end End of booster layer. 0 means do not limit trees.
*/
virtual void PredictBatch(DMatrix* dmat,
PredictionCacheEntry* out_preds,
bool training,
unsigned ntree_limit = 0) = 0;
unsigned layer_begin,
unsigned layer_end) = 0;
/*!
* \brief Inplace prediction.
@ -115,7 +114,7 @@ class GradientBooster : public Model, public Configurable {
* \param x A type erased data adapter.
* \param missing Missing value in the data.
* \param [in,out] out_preds The output preds.
* \param layer_begin (Optional) Begining of boosted tree layer used for prediction.
* \param layer_begin (Optional) Beginning of boosted tree layer used for prediction.
* \param layer_end (Optional) End of booster layer. 0 means do not limit trees.
*/
virtual void InplacePredict(dmlc::any const &, std::shared_ptr<DMatrix>, float,
@ -132,44 +131,45 @@ class GradientBooster : public Model, public Configurable {
*
* \param inst the instance you want to predict
* \param out_preds output vector to hold the predictions
* \param ntree_limit limit the number of trees used in prediction
* \param layer_begin Beginning of boosted tree layer used for prediction.
* \param layer_end End of booster layer. 0 means do not limit trees.
* \sa Predict
*/
virtual void PredictInstance(const SparsePage::Inst& inst,
std::vector<bst_float>* out_preds,
unsigned ntree_limit = 0) = 0;
unsigned layer_begin, unsigned layer_end) = 0;
/*!
* \brief predict the leaf index of each tree, the output will be nsample * ntree vector
* this is only valid in gbtree predictor
* \param dmat feature matrix
* \param out_preds output vector to hold the predictions
* \param ntree_limit limit the number of trees used in prediction, when it equals 0, this means
* we do not limit number of trees, this parameter is only valid for gbtree, but not for gblinear
* \param layer_begin Beginning of boosted tree layer used for prediction.
* \param layer_end End of booster layer. 0 means do not limit trees.
*/
virtual void PredictLeaf(DMatrix* dmat,
HostDeviceVector<bst_float>* out_preds,
unsigned ntree_limit = 0) = 0;
virtual void PredictLeaf(DMatrix *dmat,
HostDeviceVector<bst_float> *out_preds,
unsigned layer_begin, unsigned layer_end) = 0;
/*!
* \brief feature contributions to individual predictions; the output will be a vector
* of length (nfeats + 1) * num_output_group * nsample, arranged in that order
* \param dmat feature matrix
* \param out_contribs output vector to hold the contributions
* \param ntree_limit limit the number of trees used in prediction, when it equals 0, this means
* we do not limit number of trees
* \param layer_begin Beginning of boosted tree layer used for prediction.
* \param layer_end End of booster layer. 0 means do not limit trees.
* \param approximate use a faster (inconsistent) approximation of SHAP values
* \param condition condition on the condition_feature (0=no, -1=cond off, 1=cond on).
* \param condition_feature feature to condition on (i.e. fix) during calculations
*/
virtual void PredictContribution(DMatrix* dmat,
HostDeviceVector<bst_float>* out_contribs,
unsigned ntree_limit = 0,
unsigned layer_begin, unsigned layer_end,
bool approximate = false, int condition = 0,
unsigned condition_feature = 0) = 0;
virtual void PredictInteractionContributions(DMatrix* dmat,
HostDeviceVector<bst_float>* out_contribs,
unsigned ntree_limit, bool approximate) = 0;
virtual void PredictInteractionContributions(
DMatrix *dmat, HostDeviceVector<bst_float> *out_contribs,
unsigned layer_begin, unsigned layer_end, bool approximate) = 0;
/*!
* \brief dump the model in the requested format

View File

@ -113,8 +113,8 @@ class Learner : public Model, public Configurable, public dmlc::Serializable {
* \param data input data
* \param output_margin whether to only predict margin value instead of transformed prediction
* \param out_preds output vector that stores the prediction
* \param ntree_limit limit number of trees used for boosted tree
* predictor, when it equals 0, this means we are using all the trees
* \param layer_begin Beginning of boosted tree layer used for prediction.
* \param layer_end End of booster layer. 0 means do not limit trees.
* \param training Whether the prediction result is used for training
* \param pred_leaf whether to only predict the leaf index of each tree in a boosted tree predictor
* \param pred_contribs whether to only predict the feature contributions
@ -124,7 +124,8 @@ class Learner : public Model, public Configurable, public dmlc::Serializable {
virtual void Predict(std::shared_ptr<DMatrix> data,
bool output_margin,
HostDeviceVector<bst_float> *out_preds,
unsigned ntree_limit = 0,
unsigned layer_begin,
unsigned layer_end,
bool training = false,
bool pred_leaf = false,
bool pred_contribs = false,
@ -140,7 +141,7 @@ class Learner : public Model, public Configurable, public dmlc::Serializable {
* \param type Prediction type.
* \param missing Missing value in the data.
* \param [in,out] out_preds Pointer to output prediction vector.
* \param layer_begin Begining of boosted tree layer used for prediction.
* \param layer_begin Beginning of boosted tree layer used for prediction.
* \param layer_end End of booster layer. 0 means do not limit trees.
*/
virtual void InplacePredict(dmlc::any const &x,

View File

@ -127,12 +127,11 @@ class Predictor {
* \param [in,out] out_preds The output preds.
* \param model The model to predict from.
* \param tree_begin The tree begin index.
* \param ntree_limit (Optional) The ntree limit. 0 means do not
* limit trees.
* \param tree_end The tree end index.
*/
virtual void PredictBatch(DMatrix* dmat, PredictionCacheEntry* out_preds,
const gbm::GBTreeModel& model, int tree_begin,
uint32_t const ntree_limit = 0) const = 0;
const gbm::GBTreeModel& model, uint32_t tree_begin,
uint32_t tree_end = 0) const = 0;
/**
* \brief Inplace prediction.
@ -140,7 +139,7 @@ class Predictor {
* \param model The model to predict from.
* \param missing Missing value in the data.
* \param [in,out] out_preds The output preds.
* \param tree_begin (Optional) Begining of boosted trees used for prediction.
* \param tree_begin (Optional) Beginning of boosted trees used for prediction.
* \param tree_end (Optional) End of booster trees. 0 means do not limit trees.
*
* \return True if the data can be handled by current predictor, false otherwise.
@ -159,13 +158,13 @@ class Predictor {
* \param inst The instance to predict.
* \param [in,out] out_preds The output preds.
* \param model The model to predict from
* \param ntree_limit (Optional) The ntree limit.
* \param tree_end (Optional) The tree end index.
*/
virtual void PredictInstance(const SparsePage::Inst& inst,
std::vector<bst_float>* out_preds,
const gbm::GBTreeModel& model,
unsigned ntree_limit = 0) const = 0;
unsigned tree_end = 0) const = 0;
/**
* \brief predict the leaf index of each tree, the output will be nsample *
@ -174,18 +173,14 @@ class Predictor {
* \param [in,out] dmat The input feature matrix.
* \param [in,out] out_preds The output preds.
* \param model Model to make predictions from.
* \param ntree_limit (Optional) The ntree limit.
* \param tree_end (Optional) The tree end index.
*/
virtual void PredictLeaf(DMatrix* dmat, HostDeviceVector<bst_float>* out_preds,
const gbm::GBTreeModel& model,
unsigned ntree_limit = 0) const = 0;
unsigned tree_end = 0) const = 0;
/**
* \fn virtual void Predictor::PredictContribution( DMatrix* dmat,
* std::vector<bst_float>* out_contribs, const gbm::GBTreeModel& model,
* unsigned ntree_limit = 0) = 0;
*
* \brief feature contributions to individual predictions; the output will be
* a vector of length (nfeats + 1) * num_output_group * nsample, arranged in
* that order.
@ -193,7 +188,7 @@ class Predictor {
* \param [in,out] dmat The input feature matrix.
* \param [in,out] out_contribs The output feature contribs.
* \param model Model to make predictions from.
* \param ntree_limit (Optional) The ntree limit.
* \param tree_end The tree end index.
* \param tree_weights (Optional) Weights to multiply each tree by.
* \param approximate Use fast approximate algorithm.
* \param condition Condition on the condition_feature (0=no, -1=cond off, 1=cond on).
@ -203,7 +198,7 @@ class Predictor {
virtual void PredictContribution(DMatrix* dmat,
HostDeviceVector<bst_float>* out_contribs,
const gbm::GBTreeModel& model,
unsigned ntree_limit = 0,
unsigned tree_end = 0,
std::vector<bst_float>* tree_weights = nullptr,
bool approximate = false,
int condition = 0,
@ -212,7 +207,7 @@ class Predictor {
virtual void PredictInteractionContributions(DMatrix* dmat,
HostDeviceVector<bst_float>* out_contribs,
const gbm::GBTreeModel& model,
unsigned ntree_limit = 0,
unsigned tree_end = 0,
std::vector<bst_float>* tree_weights = nullptr,
bool approximate = false) const = 0;

View File

@ -96,6 +96,24 @@ def from_cstr_to_pystr(data, length):
return res
def _convert_ntree_limit(booster, ntree_limit, iteration_range):
if ntree_limit is not None and ntree_limit != 0:
warnings.warn(
"ntree_limit is deprecated, use `iteration_range` or model "
"slicing instead.",
UserWarning
)
if iteration_range is not None and iteration_range[1] != 0:
raise ValueError(
"Only one of `iteration_range` and `ntree_limit` can be non zero."
)
num_parallel_tree, num_groups = _get_booster_layer_trees(booster)
num_parallel_tree = max([num_parallel_tree, 1])
num_groups = max([num_groups, 1])
iteration_range = (0, ntree_limit // num_parallel_tree)
return iteration_range
def _expect(expectations, got):
"""Translate input error into string.
@ -1111,6 +1129,34 @@ Objective = Callable[[np.ndarray, DMatrix], Tuple[np.ndarray, np.ndarray]]
Metric = Callable[[np.ndarray, DMatrix], Tuple[str, float]]
def _get_booster_layer_trees(model: "Booster") -> Tuple[int, int]:
"""Get number of trees added to booster per-iteration. This function will be removed
once `best_ntree_limit` is dropped in favor of `best_iteration`. Returns
`num_parallel_tree` and `num_groups`.
"""
config = json.loads(model.save_config())
booster = config["learner"]["gradient_booster"]["name"]
if booster == "gblinear":
num_parallel_tree = 0
elif booster == "dart":
num_parallel_tree = int(
config["learner"]["gradient_booster"]["gbtree"]["gbtree_train_param"][
"num_parallel_tree"
]
)
elif booster == "gbtree":
num_parallel_tree = int(
config["learner"]["gradient_booster"]["gbtree_train_param"][
"num_parallel_tree"
]
)
else:
raise ValueError(f"Unknown booster: {booster}")
num_groups = int(config["learner"]["learner_model_param"]["num_class"])
return num_parallel_tree, num_groups
class Booster(object):
# pylint: disable=too-many-public-methods
"""A Booster of XGBoost.
@ -1497,16 +1543,20 @@ class Booster(object):
return self.eval_set([(data, name)], iteration)
# pylint: disable=too-many-function-args
def predict(self,
data,
output_margin=False,
ntree_limit=0,
pred_leaf=False,
pred_contribs=False,
approx_contribs=False,
pred_interactions=False,
validate_features=True,
training=False):
def predict(
self,
data: DMatrix,
output_margin: bool = False,
ntree_limit: int = 0,
pred_leaf: bool = False,
pred_contribs: bool = False,
approx_contribs: bool = False,
pred_interactions: bool = False,
validate_features: bool = True,
training: bool = False,
iteration_range: Tuple[int, int] = (0, 0),
strict_shape: bool = False,
) -> np.ndarray:
"""Predict with data.
.. note:: This function is not thread safe except for ``gbtree`` booster.
@ -1518,33 +1568,32 @@ class Booster(object):
Parameters
----------
data : DMatrix
data :
The dmatrix storing the input.
output_margin : bool
output_margin :
Whether to output the raw untransformed margin value.
ntree_limit : int
Limit number of trees in the prediction; defaults to 0 (use all
trees).
ntree_limit :
Deprecated, use `iteration_range` instead.
pred_leaf : bool
pred_leaf :
When this option is on, the output will be a matrix of (nsample,
ntrees) with each record indicating the predicted leaf index of
each sample in each tree. Note that the leaf index of a tree is
unique per tree, so you may find leaf 1 in both tree 1 and tree 0.
pred_contribs : bool
pred_contribs :
When this is True the output will be a matrix of size (nsample,
nfeats + 1) with each record indicating the feature contributions
(SHAP values) for that prediction. The sum of all feature
contributions is equal to the raw untransformed margin value of the
prediction. Note the final column is the bias term.
approx_contribs : bool
approx_contribs :
Approximate the contributions of each feature
pred_interactions : bool
pred_interactions :
When this is True the output will be a matrix of size (nsample,
nfeats + 1, nfeats + 1) indicating the SHAP interaction values for
each pair of features. The sum of each row (or column) of the
@ -1553,17 +1602,33 @@ class Booster(object):
untransformed margin value of the prediction. Note the last row and
column correspond to the bias term.
validate_features : bool
validate_features :
When this is True, validate that the Booster's and data's
feature_names are identical. Otherwise, it is assumed that the
feature_names are the same.
training : bool
training :
Whether the prediction value is used for training. This can effect
`dart` booster, which performs dropouts during training iterations.
.. versionadded:: 1.0.0
iteration_range :
Specifies which layer of trees are used in prediction. For example, if a
random forest is trained with 100 rounds. Specifying `iteration_range=(10,
20)`, then only the forests built during [10, 20) (half open set) rounds are
used in this prediction.
.. versionadded:: 1.4.0
strict_shape :
When set to True, output shape is invariant to whether classification is used.
For both value and margin prediction, the output shape is (n_samples,
n_groups), n_groups == 1 when multi-class is not used. Default to False, in
which case the output shape can be (n_samples, ) if multi-class is not used.
.. versionadded:: 1.4.0
.. note:: Using ``predict()`` with DART booster
If the booster object is DART type, ``predict()`` will not perform
@ -1575,64 +1640,50 @@ class Booster(object):
prediction : numpy array
"""
option_mask = 0x00
if output_margin:
option_mask |= 0x01
if pred_leaf:
option_mask |= 0x02
if pred_contribs:
option_mask |= 0x04
if approx_contribs:
option_mask |= 0x08
if pred_interactions:
option_mask |= 0x10
if not isinstance(data, DMatrix):
raise TypeError('Expecting data to be a DMatrix object, got: ',
type(data))
raise TypeError('Expecting data to be a DMatrix object, got: ', type(data))
if validate_features:
self._validate_features(data)
iteration_range = _convert_ntree_limit(self, ntree_limit, iteration_range)
args = {
"type": 0,
"training": training,
"iteration_begin": iteration_range[0],
"iteration_end": iteration_range[1],
"strict_shape": strict_shape,
}
length = c_bst_ulong()
preds = ctypes.POINTER(ctypes.c_float)()
_check_call(_LIB.XGBoosterPredict(self.handle, data.handle,
ctypes.c_int(option_mask),
ctypes.c_uint(ntree_limit),
ctypes.c_int(training),
ctypes.byref(length),
ctypes.byref(preds)))
preds = ctypes2numpy(preds, length.value, np.float32)
def assign_type(t: int) -> None:
if args["type"] != 0:
raise ValueError("One type of prediction at a time.")
args["type"] = t
if output_margin:
assign_type(1)
if pred_contribs:
assign_type(2 if not approx_contribs else 3)
if pred_interactions:
assign_type(4)
if pred_leaf:
preds = preds.astype(np.int32, copy=False)
nrow = data.num_row()
if preds.size != nrow and preds.size % nrow == 0:
chunk_size = int(preds.size / nrow)
if pred_interactions:
ngroup = int(chunk_size / ((data.num_col() + 1) *
(data.num_col() + 1)))
if ngroup == 1:
preds = preds.reshape(nrow,
data.num_col() + 1,
data.num_col() + 1)
else:
preds = preds.reshape(nrow, ngroup,
data.num_col() + 1,
data.num_col() + 1)
elif pred_contribs:
ngroup = int(chunk_size / (data.num_col() + 1))
if ngroup == 1:
preds = preds.reshape(nrow, data.num_col() + 1)
else:
preds = preds.reshape(nrow, ngroup, data.num_col() + 1)
else:
preds = preds.reshape(nrow, chunk_size)
return preds
assign_type(5)
preds = ctypes.POINTER(ctypes.c_float)()
shape = ctypes.POINTER(c_bst_ulong)()
dims = c_bst_ulong()
_check_call(
_LIB.XGBoosterPredictFromDMatrix(
self.handle,
data.handle,
from_pystr_to_cstr(json.dumps(args)),
ctypes.byref(shape),
ctypes.byref(dims),
ctypes.byref(preds)
)
)
return _prediction_output(shape, dims, preds, False)
def inplace_predict(
self,
data,
data: Any,
iteration_range: Tuple[int, int] = (0, 0),
predict_type: str = "value",
missing: float = np.nan,
@ -1665,26 +1716,24 @@ class Booster(object):
The input data, must not be a view for numpy array. Set
``predictor`` to ``gpu_predictor`` for running prediction on CuPy
array or CuDF DataFrame.
iteration_range : tuple
Specifies which layer of trees are used in prediction. For
example, if a random forest is trained with 100 rounds. Specifying
`iteration_range=(10, 20)`, then only the forests built during [10,
20) (open set) rounds are used in this prediction.
predict_type : str
iteration_range :
See :py:meth:`xgboost.Booster.predict` for details.
predict_type :
* `value` Output model prediction values.
* `margin` Output the raw untransformed margin value.
missing : float
Value in the input data which needs to be present as a missing
value.
missing :
See :py:obj:`xgboost.DMatrix` for details.
validate_features:
See :py:meth:`xgboost.Booster.predict` for details.
base_margin:
See :py:obj:`xgboost.DMatrix` for details.
.. versionadded:: 1.4.0
strict_shape:
When set to True, output shape is invariant to whether classification is used.
For both value and margin prediction, the output shape is (n_samples,
n_groups), n_groups == 1 when multi-class is not used. Default to False, in
which case the output shape can be (n_samples, ) if multi-class is not used.
See :py:meth:`xgboost.Booster.predict` for details.
.. versionadded:: 1.4.0
Returns
-------
@ -1772,7 +1821,7 @@ class Booster(object):
interface["mask"] = interface["mask"].__cuda_array_interface__
interface_str = bytes(json.dumps(interface, indent=2), "utf-8")
_check_call(
_LIB.XGBoosterPredictFromArrayInterface(
_LIB.XGBoosterPredictFromCudaArray(
self.handle,
interface_str,
from_pystr_to_cstr(json.dumps(args)),
@ -1788,7 +1837,7 @@ class Booster(object):
interfaces_str = _cudf_array_interfaces(data)
_check_call(
_LIB.XGBoosterPredictFromArrayInterfaceColumns(
_LIB.XGBoosterPredictFromCudaColumnar(
self.handle,
interfaces_str,
from_pystr_to_cstr(json.dumps(args)),

View File

@ -254,6 +254,7 @@ class DaskDMatrix:
raise TypeError(_expect((dd.DataFrame, da.Array, dd.Series), type(label)))
self._n_cols = data.shape[1]
assert isinstance(self._n_cols, int)
self.worker_map: Dict[str, "distributed.Future"] = defaultdict(list)
self.is_quantile: bool = False
@ -881,7 +882,7 @@ async def _train_async(
return list(filter(lambda ret: ret is not None, results))[0]
def train(
def train( # pylint: disable=unused-argument
client: "distributed.Client",
params: Dict[str, Any],
dtrain: DaskDMatrix,
@ -892,16 +893,17 @@ def train(
early_stopping_rounds: Optional[int] = None,
xgb_model: Optional[Booster] = None,
verbose_eval: Union[int, bool] = True,
callbacks: Optional[List[TrainingCallback]] = None
callbacks: Optional[List[TrainingCallback]] = None,
) -> Any:
'''Train XGBoost model.
"""Train XGBoost model.
.. versionadded:: 1.0.0
.. note::
Other parameters are the same as `xgboost.train` except for `evals_result`, which
is returned as part of function return value instead of argument.
Other parameters are the same as :py:func:`xgboost.train` except for
`evals_result`, which is returned as part of function return value instead of
argument.
Parameters
----------
@ -920,29 +922,17 @@ def train(
{'booster': xgboost.Booster,
'history': {'train': {'logloss': ['0.48253', '0.35953']},
'eval': {'logloss': ['0.480385', '0.357756']}}}
'''
"""
_assert_dask_support()
client = _xgb_get_client(client)
# Get global configuration before transferring computation to another thread or
# process.
global_config = config.get_config()
return client.sync(_train_async,
client=client,
global_config=global_config,
num_boost_round=num_boost_round,
obj=obj,
feval=feval,
params=params,
dtrain=dtrain,
evals=evals,
early_stopping_rounds=early_stopping_rounds,
verbose_eval=verbose_eval,
xgb_model=xgb_model,
callbacks=callbacks)
return client.sync(_train_async, global_config=config.get_config(), **locals())
def _can_output_df(data: _DaskCollection, output_shape: Tuple) -> bool:
return isinstance(data, dd.DataFrame) and len(output_shape) <= 2
def _can_output_df(is_df: bool, output_shape: Tuple) -> bool:
return is_df and len(output_shape) <= 2
async def _direct_predict_impl(
@ -954,8 +944,9 @@ async def _direct_predict_impl(
meta: Dict[int, str],
) -> _DaskCollection:
columns = list(meta.keys())
if _can_output_df(data, output_shape):
if _can_output_df(isinstance(data, dd.DataFrame), output_shape):
if base_margin is not None and isinstance(base_margin, da.Array):
# Easier for map_partitions
base_margin_df: Optional[dd.DataFrame] = base_margin.to_dask_dataframe()
else:
base_margin_df = base_margin
@ -975,17 +966,21 @@ async def _direct_predict_impl(
if base_margin is not None and isinstance(
base_margin, (dd.Series, dd.DataFrame)
):
# Easier for map_blocks
base_margin_array: Optional[da.Array] = base_margin.to_dask_array()
else:
base_margin_array = base_margin
# Input data is 2-dim array, output can be 1(reg, binary)/2(multi-class,
# contrib)/3(contrib)/4(interaction) dims.
# contrib)/3(contrib, interaction)/4(interaction) dims.
if len(output_shape) == 1:
drop_axis: Union[int, List[int]] = [1] # drop from 2 to 1 dim.
new_axis: Union[int, List[int]] = []
else:
drop_axis = []
new_axis = [i + 2 for i in range(len(output_shape) - 2)]
if isinstance(data, dd.DataFrame):
new_axis = list(range(len(output_shape) - 2))
else:
new_axis = [i + 2 for i in range(len(output_shape) - 2)]
predictions = da.map_blocks(
mapped_predict,
booster,
@ -1001,28 +996,21 @@ async def _direct_predict_impl(
def _infer_predict_output(
booster: Booster,
data: Union[DaskDMatrix, _DaskCollection],
inplace: bool,
**kwargs: Any
booster: Booster, features: int, is_df: bool, inplace: bool, **kwargs: Any
) -> Tuple[Tuple[int, ...], Dict[int, str]]:
"""Create a dummy test sample to infer output shape for prediction."""
if isinstance(data, DaskDMatrix):
features = data.num_col()
else:
features = data.shape[1]
assert isinstance(features, int)
rng = numpy.random.RandomState(1994)
test_sample = rng.randn(1, features)
if inplace:
# clear the state to avoid gpu_id, gpu_predictor
booster = Booster(model_file=booster.save_raw())
test_predt = booster.inplace_predict(test_sample, **kwargs)
else:
m = DMatrix(test_sample)
test_predt = booster.predict(m, **kwargs)
kwargs = kwargs.copy()
if kwargs.pop("predict_type") == "margin":
kwargs["output_margin"] = True
m = DMatrix(test_sample)
test_predt = booster.predict(m, validate_features=False, **kwargs)
n_columns = test_predt.shape[1] if len(test_predt.shape) > 1 else 1
meta: Dict[int, str] = {}
if _can_output_df(data, test_predt.shape):
if _can_output_df(is_df, test_predt.shape):
for i in range(n_columns):
meta[i] = "f4"
return test_predt.shape, meta
@ -1034,7 +1022,7 @@ async def _get_model_future(
if isinstance(model, Booster):
booster = await client.scatter(model, broadcast=True)
elif isinstance(model, dict):
booster = await client.scatter(model["booster"])
booster = await client.scatter(model["booster"], broadcast=True)
elif isinstance(model, distributed.Future):
booster = model
if booster.type is not Booster:
@ -1059,6 +1047,8 @@ async def _predict_async(
approx_contribs: bool,
pred_interactions: bool,
validate_features: bool,
iteration_range: Tuple[int, int],
strict_shape: bool,
) -> _DaskCollection:
_booster = await _get_model_future(client, model)
if not isinstance(data, (DaskDMatrix, da.Array, dd.DataFrame)):
@ -1077,43 +1067,51 @@ async def _predict_async(
approx_contribs=approx_contribs,
pred_interactions=pred_interactions,
validate_features=validate_features,
iteration_range=iteration_range,
strict_shape=strict_shape,
)
if is_df and len(predt.shape) <= 2:
if _can_output_df(is_df, predt.shape):
if lazy_isinstance(partition, "cudf", "core.dataframe.DataFrame"):
import cudf
predt = cudf.DataFrame(predt, columns=columns)
predt = cudf.DataFrame(predt, columns=columns, dtype=numpy.float32)
else:
predt = DataFrame(predt, columns=columns)
predt = DataFrame(predt, columns=columns, dtype=numpy.float32)
return predt
# Predict on dask collection directly.
if isinstance(data, (da.Array, dd.DataFrame)):
_output_shape, meta = _infer_predict_output(
await _booster.result(),
data,
_output_shape, meta = await client.compute(
client.submit(
_infer_predict_output,
_booster,
features=data.shape[1],
is_df=isinstance(data, dd.DataFrame),
inplace=False,
output_margin=output_margin,
pred_leaf=pred_leaf,
pred_contribs=pred_contribs,
approx_contribs=approx_contribs,
pred_interactions=pred_interactions,
)
)
return await _direct_predict_impl(
mapped_predict, _booster, data, None, _output_shape, meta
)
output_shape, _ = await client.compute(
client.submit(
_infer_predict_output,
booster=_booster,
features=data.num_col(),
is_df=False,
inplace=False,
output_margin=output_margin,
pred_leaf=pred_leaf,
pred_contribs=pred_contribs,
approx_contribs=approx_contribs,
pred_interactions=pred_interactions,
validate_features=False,
)
return await _direct_predict_impl(
mapped_predict, _booster, data, None, _output_shape, meta
)
output_shape, _ = _infer_predict_output(
booster=await _booster.result(),
data=data,
inplace=False,
output_margin=output_margin,
pred_leaf=pred_leaf,
pred_contribs=pred_contribs,
approx_contribs=approx_contribs,
pred_interactions=pred_interactions,
validate_features=False,
)
# Prediction on dask DMatrix.
partition_order = data.partition_order
@ -1180,7 +1178,7 @@ async def _predict_async(
futures[i], shape=(rows,) + output_shape[1:], dtype=numpy.float32
)
)
predictions = await da.concatenate(arrays, axis=0)
predictions = da.concatenate(arrays, axis=0)
return predictions
@ -1194,15 +1192,19 @@ def predict( # pylint: disable=unused-argument
pred_contribs: bool = False,
approx_contribs: bool = False,
pred_interactions: bool = False,
validate_features: bool = True
validate_features: bool = True,
iteration_range: Tuple[int, int] = (0, 0),
strict_shape: bool = False,
) -> Any:
'''Run prediction with a trained booster.
.. note::
Using ``inplace_predict `` might be faster when meta information like
``base_margin`` is not needed. For other parameters, please see
``Booster.predict``.
Using ``inplace_predict`` might be faster when some features are not needed. See
:py:meth:`xgboost.Booster.predict` for details on various parameters. When using
``pred_interactions`` with mutli-class model, input should be ``da.Array`` or
``DaskDMatrix`` due to limitation in ``da.map_blocks``.
.. versionadded:: 1.0.0
@ -1232,69 +1234,83 @@ def predict( # pylint: disable=unused-argument
'''
_assert_dask_support()
client = _xgb_get_client(client)
return client.sync(
_predict_async, global_config=config.get_config(), **locals()
)
return client.sync(_predict_async, global_config=config.get_config(), **locals())
async def _inplace_predict_async(
async def _inplace_predict_async( # pylint: disable=too-many-branches
client: "distributed.Client",
global_config: Dict[str, Any],
model: Union[Booster, Dict, "distributed.Future"],
data: _DaskCollection,
iteration_range: Tuple[int, int] = (0, 0),
predict_type: str = 'value',
missing: float = numpy.nan
iteration_range: Tuple[int, int],
predict_type: str,
missing: float,
validate_features: bool,
base_margin: Optional[_DaskCollection],
strict_shape: bool,
) -> _DaskCollection:
client = _xgb_get_client(client)
booster = await _get_model_future(client, model)
if not isinstance(data, (da.Array, dd.DataFrame)):
raise TypeError(_expect([da.Array, dd.DataFrame], type(data)))
if base_margin is not None and not isinstance(
data, (da.Array, dd.DataFrame, dd.Series)
):
raise TypeError(_expect([da.Array, dd.DataFrame, dd.Series], type(base_margin)))
def mapped_predict(
booster: Booster, data: Any, is_df: bool, columns: List[int], _: Any
booster: Booster, data: Any, is_df: bool, columns: List[int], base_margin: Any
) -> Any:
with config.config_context(**global_config):
prediction = booster.inplace_predict(
data,
iteration_range=iteration_range,
predict_type=predict_type,
missing=missing
missing=missing,
base_margin=base_margin,
validate_features=validate_features,
strict_shape=strict_shape,
)
if is_df and len(prediction.shape) <= 2:
if lazy_isinstance(data, 'cudf.core.dataframe', 'DataFrame'):
if _can_output_df(is_df, prediction.shape):
if lazy_isinstance(data, "cudf.core.dataframe", "DataFrame"):
import cudf
prediction = cudf.DataFrame(
prediction, columns=columns, dtype=numpy.float32
)
else:
# If it's from pandas, the partition is a numpy array
prediction = DataFrame(
prediction, columns=columns, dtype=numpy.float32
)
# If it's from pandas, the partition is a numpy array
prediction = DataFrame(prediction, columns=columns, dtype=numpy.float32)
return prediction
shape, meta = _infer_predict_output(
await booster.result(),
data,
True,
predict_type=predict_type,
iteration_range=iteration_range
# await turns future into value.
shape, meta = await client.compute(
client.submit(
_infer_predict_output,
booster,
features=data.shape[1],
is_df=isinstance(data, dd.DataFrame),
inplace=True,
predict_type=predict_type,
iteration_range=iteration_range,
)
)
return await _direct_predict_impl(
mapped_predict, booster, data, None, shape, meta
mapped_predict, booster, data, base_margin, shape, meta
)
def inplace_predict( # pylint: disable=unused-argument
def inplace_predict( # pylint: disable=unused-argument
client: "distributed.Client",
model: Union[TrainReturnT, Booster, "distributed.Future"],
data: _DaskCollection,
iteration_range: Tuple[int, int] = (0, 0),
predict_type: str = 'value',
missing: float = numpy.nan
predict_type: str = "value",
missing: float = numpy.nan,
validate_features: bool = True,
base_margin: Optional[_DaskCollection] = None,
strict_shape: bool = False,
) -> Any:
'''Inplace prediction.
"""Inplace prediction. See doc in :py:meth:`xgboost.Booster.inplace_predict` for details.
.. versionadded:: 1.1.0
@ -1304,16 +1320,27 @@ def inplace_predict( # pylint: disable=unused-argument
Specify the dask client used for training. Use default client
returned from dask if it's set to None.
model:
The trained model. It can be a distributed.Future so user can
pre-scatter it onto all workers.
See :py:func:`xgboost.dask.predict` for details.
data :
dask collection.
iteration_range:
Specify the range of trees used for prediction.
See :py:meth:`xgboost.Booster.predict` for details.
predict_type:
* 'value': Normal prediction result.
* 'margin': Output the raw untransformed margin value.
See :py:meth:`xgboost.Booster.inplace_predict` for details.
missing:
Value in the input data which needs to be present as a missing
value. If None, defaults to np.nan.
base_margin:
See :py:obj:`xgboost.DMatrix` for details. Right now classifier is not well
supported with base_margin as it requires the size of base margin to be `n_classes
* n_samples`.
.. versionadded:: 1.4.0
strict_shape:
See :py:meth:`xgboost.Booster.predict` for details.
.. versionadded:: 1.4.0
Returns
-------
@ -1322,7 +1349,7 @@ def inplace_predict( # pylint: disable=unused-argument
data is ``dask.dataframe.DataFrame``, return value can be
``dask.dataframe.Series``, ``dask.dataframe.DataFrame`` or ``dask.array.Array``,
depending on the output shape.
'''
"""
_assert_dask_support()
client = _xgb_get_client(client)
return client.sync(
@ -1334,9 +1361,11 @@ async def _async_wrap_evaluation_matrices(
client: "distributed.Client", **kwargs: Any
) -> Tuple[DaskDMatrix, Optional[List[Tuple[DaskDMatrix, str]]]]:
"""A switch function for async environment."""
def _inner(**kwargs: Any) -> DaskDMatrix:
m = DaskDMatrix(client=client, **kwargs)
return m
train_dmatrix, evals = _wrap_evaluation_matrices(create_dmatrix=_inner, **kwargs)
train_dmatrix = await train_dmatrix
if evals is None:
@ -1351,25 +1380,45 @@ async def _async_wrap_evaluation_matrices(
class DaskScikitLearnBase(XGBModel):
'''Base class for implementing scikit-learn interface with Dask'''
"""Base class for implementing scikit-learn interface with Dask"""
_client = None
async def _predict_async(
self, data: _DaskCollection,
output_margin: bool = False,
validate_features: bool = True,
base_margin: Optional[_DaskCollection] = None
self,
data: _DaskCollection,
output_margin: bool,
validate_features: bool,
base_margin: Optional[_DaskCollection],
iteration_range: Optional[Tuple[int, int]],
) -> Any:
test_dmatrix = await DaskDMatrix(
client=self.client, data=data, base_margin=base_margin,
missing=self.missing
)
pred_probs = await predict(client=self.client,
model=self.get_booster(), data=test_dmatrix,
output_margin=output_margin,
validate_features=validate_features)
return pred_probs
iteration_range = self._get_iteration_range(iteration_range)
if self._can_use_inplace_predict():
predts = await inplace_predict(
client=self.client,
model=self.get_booster(),
data=data,
iteration_range=iteration_range,
predict_type="margin" if output_margin else "value",
missing=self.missing,
base_margin=base_margin,
validate_features=validate_features,
)
if isinstance(predts, dd.DataFrame):
predts = predts.to_dask_array()
else:
test_dmatrix = await DaskDMatrix(
self.client, data=data, base_margin=base_margin, missing=self.missing
)
predts = await predict(
self.client,
model=self.get_booster(),
data=test_dmatrix,
output_margin=output_margin,
validate_features=validate_features,
iteration_range=iteration_range,
)
return predts
def predict(
self,
@ -1377,26 +1426,56 @@ class DaskScikitLearnBase(XGBModel):
output_margin: bool = False,
ntree_limit: Optional[int] = None,
validate_features: bool = True,
base_margin: Optional[_DaskCollection] = None
base_margin: Optional[_DaskCollection] = None,
iteration_range: Optional[Tuple[int, int]] = None,
) -> Any:
_assert_dask_support()
msg = '`ntree_limit` is not supported on dask, use model slicing instead.'
msg = "`ntree_limit` is not supported on dask, use `iteration_range` instead."
assert ntree_limit is None, msg
return self.client.sync(
self._predict_async,
X,
output_margin=output_margin,
validate_features=validate_features,
base_margin=base_margin
base_margin=base_margin,
iteration_range=iteration_range,
)
async def _apply_async(
self,
X: _DaskCollection,
iteration_range: Optional[Tuple[int, int]] = None,
) -> Any:
iteration_range = self._get_iteration_range(iteration_range)
test_dmatrix = await DaskDMatrix(self.client, data=X, missing=self.missing)
predts = await predict(
self.client,
model=self.get_booster(),
data=test_dmatrix,
pred_leaf=True,
iteration_range=iteration_range,
)
return predts
def apply(
self,
X: _DaskCollection,
ntree_limit: Optional[int] = None,
iteration_range: Optional[Tuple[int, int]] = None,
) -> Any:
_assert_dask_support()
msg = "`ntree_limit` is not supported on dask, use `iteration_range` instead."
assert ntree_limit is None, msg
return self.client.sync(self._apply_async, X, iteration_range=iteration_range)
def __await__(self) -> Awaitable[Any]:
# Generate a coroutine wrapper to make this class awaitable.
async def _() -> Awaitable[Any]:
return self
return self.client.sync(_).__await__()
def __getstate__(self):
def __getstate__(self) -> Dict:
this = self.__dict__.copy()
if "_client" in this.keys():
del this["_client"]
@ -1404,7 +1483,7 @@ class DaskScikitLearnBase(XGBModel):
@property
def client(self) -> "distributed.Client":
'''The dask client used in this model.'''
"""The dask client used in this model."""
client = _xgb_get_client(self._client)
return client
@ -1494,7 +1573,7 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
sample_weight_eval_set: Optional[List[_DaskCollection]] = None,
base_margin_eval_set: Optional[List[_DaskCollection]] = None,
feature_weights: Optional[_DaskCollection] = None,
callbacks: Optional[List[TrainingCallback]] = None
callbacks: Optional[List[TrainingCallback]] = None,
) -> "DaskXGBRegressor":
_assert_dask_support()
args = {k: v for k, v in locals().items() if k != "self"}
@ -1556,9 +1635,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
else:
obj = None
model, metric, params = self._configure_fit(
booster=xgb_model,
eval_metric=eval_metric,
params=params
booster=xgb_model, eval_metric=eval_metric, params=params
)
results = await train(
client=self.client,
@ -1610,18 +1687,19 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
X: _DaskCollection,
validate_features: bool,
output_margin: bool,
base_margin: Optional[_DaskCollection]
base_margin: Optional[_DaskCollection],
iteration_range: Optional[Tuple[int, int]],
) -> _DaskCollection:
test_dmatrix = await DaskDMatrix(
client=self.client, data=X, base_margin=base_margin,
missing=self.missing
if iteration_range is None:
iteration_range = (0, 0)
predts = await super()._predict_async(
data=X,
output_margin=output_margin,
validate_features=validate_features,
base_margin=base_margin,
iteration_range=iteration_range,
)
pred_probs = await predict(client=self.client,
model=self.get_booster(),
data=test_dmatrix,
validate_features=validate_features,
output_margin=output_margin)
return _cls_predict_proba(self.objective, pred_probs, da.vstack)
return _cls_predict_proba(self.objective, predts, da.vstack)
# pylint: disable=missing-function-docstring
def predict_proba(
@ -1630,37 +1708,49 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
ntree_limit: Optional[int] = None,
validate_features: bool = True,
output_margin: bool = False,
base_margin: Optional[_DaskCollection] = None
base_margin: Optional[_DaskCollection] = None,
iteration_range: Optional[Tuple[int, int]] = None,
) -> Any:
_assert_dask_support()
msg = '`ntree_limit` is not supported on dask, use model slicing instead.'
msg = "`ntree_limit` is not supported on dask, use `iteration_range` instead."
assert ntree_limit is None, msg
return self.client.sync(
self._predict_proba_async,
X=X,
validate_features=validate_features,
output_margin=output_margin,
base_margin=base_margin
base_margin=base_margin,
iteration_range=iteration_range,
)
predict_proba.__doc__ = XGBClassifier.predict_proba.__doc__
async def _predict_async(
self, data: _DaskCollection,
output_margin: bool = False,
validate_features: bool = True,
base_margin: Optional[_DaskCollection] = None
self,
data: _DaskCollection,
output_margin: bool,
validate_features: bool,
base_margin: Optional[_DaskCollection],
iteration_range: Optional[Tuple[int, int]],
) -> _DaskCollection:
pred_probs = await super()._predict_async(
data, output_margin, validate_features, base_margin
data, output_margin, validate_features, base_margin, iteration_range
)
if output_margin:
return pred_probs
if self.n_classes_ == 2:
if len(pred_probs.shape) == 1:
preds = (pred_probs > 0.5).astype(int)
else:
preds = da.argmax(pred_probs, axis=1)
assert len(pred_probs.shape) == 2
assert isinstance(pred_probs, da.Array)
# when using da.argmax directly, dask will construct a numpy based return
# array, which runs into error when computing GPU based prediction.
def _argmax(x: Any) -> Any:
return x.argmax(axis=1)
preds = da.map_blocks(_argmax, pred_probs, drop_axis=1)
return preds
@ -1770,7 +1860,7 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
callbacks: Optional[List[TrainingCallback]] = None
) -> "DaskXGBRanker":
_assert_dask_support()
args = {k: v for k, v in locals().items() if k != 'self'}
args = {k: v for k, v in locals().items() if k != "self"}
return self.client.sync(self._fit_async, **args)
# FIXME(trivialfis): arguments differ due to additional parameters like group and qid.

View File

@ -6,7 +6,8 @@ import warnings
import json
from typing import Union, Optional, List, Dict, Callable, Tuple, Any
import numpy as np
from .core import Booster, DMatrix, XGBoostError, _deprecate_positional_args
from .core import Booster, DMatrix, XGBoostError
from .core import _deprecate_positional_args, _convert_ntree_limit
from .core import Metric
from .training import train
from .data import _is_cudf_df, _is_cudf_ser, _is_cupy_array
@ -413,8 +414,8 @@ class XGBModel(XGBModelBase):
# Simple optimization to gain speed (inspect is slow)
return self
# this concatenates kwargs into paraemters, enabling `get_params` for
# obtaining parameters from keyword paraemters.
# this concatenates kwargs into parameters, enabling `get_params` for
# obtaining parameters from keyword parameters.
for key, value in params.items():
if hasattr(self, key):
setattr(self, key, value)
@ -747,26 +748,45 @@ class XGBModel(XGBModelBase):
self._set_evaluation_result(evals_result)
return self
def _can_use_inplace_predict(self) -> bool:
# When predictor is explicitly set, using `inplace_predict` might result into
# error with incompatible data type.
# Inplace predict doesn't handle as many data types as DMatrix, but it's
# sufficient for dask interface where input is simpiler.
params = self.get_params()
booster = self.booster
if params.get("predictor", None) is None and (
booster is None or booster == "gbtree"
):
return True
return False
def _get_iteration_range(
self, iteration_range: Optional[Tuple[int, int]]
) -> Tuple[int, int]:
if (iteration_range is None or iteration_range[1] == 0):
# Use best_iteration if defined.
try:
iteration_range = (0, self.best_iteration + 1)
except AttributeError:
iteration_range = (0, 0)
if self.booster == "gblinear":
iteration_range = (0, 0)
return iteration_range
def predict(
self,
X,
output_margin=False,
ntree_limit=None,
validate_features=True,
base_margin=None
base_margin=None,
iteration_range=None,
):
"""
Predict with `X`.
.. note:: This function is not thread safe.
For each booster object, predict can only be called from one thread.
If you want to run prediction using multiple thread, call ``xgb.copy()`` to make copies
of model object and then call ``predict()``.
.. code-block:: python
preds = bst.predict(dtest, ntree_limit=num_round)
.. note:: This function is only thread safe for `gbtree`
Parameters
----------
@ -775,37 +795,40 @@ class XGBModel(XGBModelBase):
output_margin : bool
Whether to output the raw untransformed margin value.
ntree_limit : int
Limit number of trees in the prediction; defaults to best_ntree_limit if
defined (i.e. it has been trained with early stopping), otherwise 0 (use all
trees).
Deprecated, use `iteration_range` instead.
validate_features : bool
When this is True, validate that the Booster's and data's feature_names are identical.
Otherwise, it is assumed that the feature_names are the same.
When this is True, validate that the Booster's and data's feature_names are
identical. Otherwise, it is assumed that the feature_names are the same.
base_margin : array_like
Margin added to prediction.
iteration_range :
Specifies which layer of trees are used in prediction. For example, if a
random forest is trained with 100 rounds. Specifying `iteration_range=(10,
20)`, then only the forests built during [10, 20) (half open set) rounds are
used in this prediction.
.. versionadded:: 1.4.0
Returns
-------
prediction : numpy array
"""
# pylint: disable=missing-docstring,invalid-name
test_dmatrix = DMatrix(X, base_margin=base_margin,
missing=self.missing, nthread=self.n_jobs)
# get ntree_limit to use - if none specified, default to
# best_ntree_limit if defined, otherwise 0.
if ntree_limit is None:
try:
ntree_limit = self.best_ntree_limit
except AttributeError:
ntree_limit = 0
iteration_range = _convert_ntree_limit(
self.get_booster(), ntree_limit, iteration_range
)
iteration_range = self._get_iteration_range(iteration_range)
test = DMatrix(
X, base_margin=base_margin, missing=self.missing, nthread=self.n_jobs
)
return self.get_booster().predict(
test_dmatrix,
data=test,
iteration_range=iteration_range,
output_margin=output_margin,
ntree_limit=ntree_limit,
validate_features=validate_features
validate_features=validate_features,
)
def apply(self, X, ntree_limit=0):
def apply(
self, X, ntree_limit: int = 0, iteration_range: Optional[Tuple[int, int]] = None
) -> np.ndarray:
"""Return the predicted leaf every tree for each sample.
Parameters
@ -823,10 +846,16 @@ class XGBModel(XGBModelBase):
leaf x ends up in. Leaves are numbered within
``[0; 2**(self.max_depth+1))``, possibly with gaps in the numbering.
"""
iteration_range = _convert_ntree_limit(
self.get_booster(), ntree_limit, iteration_range
)
iteration_range = self._get_iteration_range(iteration_range)
test_dmatrix = DMatrix(X, missing=self.missing, nthread=self.n_jobs)
return self.get_booster().predict(test_dmatrix,
pred_leaf=True,
ntree_limit=ntree_limit)
return self.get_booster().predict(
test_dmatrix,
pred_leaf=True,
iteration_range=iteration_range
)
def evals_result(self):
"""Return the evaluation results.
@ -945,8 +974,7 @@ class XGBModel(XGBModelBase):
'Coefficients are not defined for Booster type {}'
.format(self.booster))
b = self.get_booster()
coef = np.array(json.loads(
b.get_dump(dump_format='json')[0])['weight'])
coef = np.array(json.loads(b.get_dump(dump_format='json')[0])['weight'])
# Logic for multiclass classification
n_classes = getattr(self, 'n_classes_', None)
if n_classes is not None:
@ -1157,14 +1185,16 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
output_margin=False,
ntree_limit=None,
validate_features=True,
base_margin=None
base_margin=None,
iteration_range: Optional[Tuple[int, int]] = None,
):
class_probs = super().predict(
X=X,
output_margin=output_margin,
ntree_limit=ntree_limit,
validate_features=validate_features,
base_margin=base_margin
base_margin=base_margin,
iteration_range=iteration_range,
)
if output_margin:
# If output_margin is active, simply return the scores
@ -1180,29 +1210,34 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
return self._le.inverse_transform(column_indexes)
return column_indexes
def predict_proba(self, X, ntree_limit=None, validate_features=False,
base_margin=None):
def predict_proba(
self,
X,
ntree_limit=None,
validate_features=False,
base_margin=None,
iteration_range: Optional[Tuple[int, int]] = None,
) -> np.ndarray:
""" Predict the probability of each `X` example being of a given class.
.. note:: This function is not thread safe
For each booster object, predict can only be called from one
thread. If you want to run prediction using multiple thread, call
``xgb.copy()`` to make copies of model object and then call predict
.. note:: This function is only thread safe for `gbtree`
Parameters
----------
X : array_like
Feature matrix.
ntree_limit : int
Limit number of trees in the prediction; defaults to best_ntree_limit if
defined (i.e. it has been trained with early stopping), otherwise 0 (use all
trees).
Deprecated, use `iteration_range` instead.
validate_features : bool
When this is True, validate that the Booster's and data's feature_names are
identical. Otherwise, it is assumed that the feature_names are the same.
base_margin : array_like
Margin added to prediction.
iteration_range :
Specifies which layer of trees are used in prediction. For example, if a
random forest is trained with 100 rounds. Specifying `iteration_range=(10,
20)`, then only the forests built during [10, 20) (half open set) rounds are
used in this prediction.
Returns
-------
@ -1215,7 +1250,8 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
output_margin=False,
ntree_limit=ntree_limit,
validate_features=validate_features,
base_margin=base_margin
base_margin=base_margin,
iteration_range=iteration_range
)
return _cls_predict_proba(self.objective, class_probs, np.vstack)

View File

@ -4,9 +4,8 @@
"""Training Library containing training routines."""
import warnings
import copy
import json
import numpy as np
from .core import Booster, XGBoostError
from .core import Booster, XGBoostError, _get_booster_layer_trees
from .compat import (SKLEARN_INSTALLED, XGBStratifiedKFold)
from . import callback
@ -91,24 +90,7 @@ def _train_internal(params, dtrain,
# These should be moved into callback functions `after_training`, but until old
# callbacks are removed, the train function is the only place for setting the
# attributes.
config = json.loads(bst.save_config())
booster = config['learner']['gradient_booster']['name']
if booster == 'gblinear':
num_parallel_tree = 0
elif booster == 'dart':
num_parallel_tree = int(
config['learner']['gradient_booster']['gbtree']['gbtree_train_param'][
'num_parallel_tree'
]
)
elif booster == 'gbtree':
num_parallel_tree = int(
config['learner']['gradient_booster']['gbtree_train_param'][
'num_parallel_tree']
)
else:
raise ValueError(f'Unknown booster: {booster}')
num_parallel_tree, _ = _get_booster_layer_trees(bst)
if bst.attr('best_score') is not None:
bst.best_score = float(bst.attr('best_score'))
bst.best_iteration = int(bst.attr('best_iteration'))

View File

@ -619,20 +619,58 @@ XGB_DLL int XGBoosterPredict(BoosterHandle handle,
CHECK_HANDLE();
auto *learner = static_cast<Learner*>(handle);
auto& entry = learner->GetThreadLocal().prediction_entry;
learner->Predict(
*static_cast<std::shared_ptr<DMatrix>*>(dmat),
(option_mask & 1) != 0,
&entry.predictions, ntree_limit,
static_cast<bool>(training),
(option_mask & 2) != 0,
(option_mask & 4) != 0,
(option_mask & 8) != 0,
(option_mask & 16) != 0);
auto iteration_end = GetIterationFromTreeLimit(ntree_limit, learner);
learner->Predict(*static_cast<std::shared_ptr<DMatrix> *>(dmat),
(option_mask & 1) != 0, &entry.predictions, 0, iteration_end,
static_cast<bool>(training), (option_mask & 2) != 0,
(option_mask & 4) != 0, (option_mask & 8) != 0,
(option_mask & 16) != 0);
*out_result = dmlc::BeginPtr(entry.predictions.ConstHostVector());
*len = static_cast<xgboost::bst_ulong>(entry.predictions.Size());
API_END();
}
XGB_DLL int XGBoosterPredictFromDMatrix(BoosterHandle handle,
DMatrixHandle dmat,
char const* c_json_config,
xgboost::bst_ulong const **out_shape,
xgboost::bst_ulong *out_dim,
bst_float const **out_result) {
API_BEGIN();
if (handle == nullptr) {
LOG(FATAL) << "Booster has not been intialized or has already been disposed.";
}
if (dmat == nullptr) {
LOG(FATAL) << "DMatrix has not been intialized or has already been disposed.";
}
auto config = Json::Load(StringView{c_json_config});
auto *learner = static_cast<Learner*>(handle);
auto& entry = learner->GetThreadLocal().prediction_entry;
auto p_m = *static_cast<std::shared_ptr<DMatrix> *>(dmat);
auto type = PredictionType(get<Integer const>(config["type"]));
auto iteration_begin = get<Integer const>(config["iteration_begin"]);
auto iteration_end = get<Integer const>(config["iteration_end"]);
learner->Predict(
*static_cast<std::shared_ptr<DMatrix> *>(dmat),
type == PredictionType::kMargin, &entry.predictions, iteration_begin,
iteration_end, get<Boolean const>(config["training"]),
type == PredictionType::kLeaf, type == PredictionType::kContribution,
type == PredictionType::kApproxContribution,
type == PredictionType::kInteraction);
*out_result = dmlc::BeginPtr(entry.predictions.ConstHostVector());
auto &shape = learner->GetThreadLocal().prediction_shape;
auto chunksize = p_m->Info().num_row_ == 0 ? 0 : entry.predictions.Size() / p_m->Info().num_row_;
auto rounds = iteration_end - iteration_begin;
rounds = rounds == 0 ? learner->BoostedRounds() : rounds;
// Determine shape
bool strict_shape = get<Boolean const>(config["strict_shape"]);
CalcPredictShape(strict_shape, type, p_m->Info().num_row_,
p_m->Info().num_col_, chunksize, learner->Groups(), rounds,
&shape, out_dim);
*out_shape = dmlc::BeginPtr(shape);
API_END();
}
template <typename T>
void InplacePredictImpl(std::shared_ptr<T> x, std::shared_ptr<DMatrix> p_m,
@ -705,7 +743,7 @@ XGB_DLL int XGBoosterPredictFromCSR(BoosterHandle handle, char const *indptr,
}
#if !defined(XGBOOST_USE_CUDA)
XGB_DLL int XGBoosterPredictFromArrayInterface(
XGB_DLL int XGBoosterPredictFromCUDAArray(
BoosterHandle handle, char const *c_json_strs, char const *c_json_config,
DMatrixHandle m, xgboost::bst_ulong const **out_shape, xgboost::bst_ulong *out_dim,
const float **out_result) {
@ -715,7 +753,7 @@ XGB_DLL int XGBoosterPredictFromArrayInterface(
API_END();
}
XGB_DLL int XGBoosterPredictFromArrayInterfaceColumns(
XGB_DLL int XGBoosterPredictFromCUDAColumnar(
BoosterHandle handle, char const *c_json_strs, char const *c_json_config,
DMatrixHandle m, xgboost::bst_ulong const **out_shape, xgboost::bst_ulong *out_dim,
const float **out_result) {

View File

@ -66,8 +66,7 @@ int InplacePreidctCuda(BoosterHandle handle, char const *c_json_strs,
API_END();
}
// A hidden API as cache id is not being supported yet.
XGB_DLL int XGBoosterPredictFromArrayInterfaceColumns(
XGB_DLL int XGBoosterPredictFromCudaColumnar(
BoosterHandle handle, char const *c_json_strs, char const *c_json_config,
DMatrixHandle m, xgboost::bst_ulong const **out_shape,
xgboost::bst_ulong *out_dim, const float **out_result) {
@ -79,8 +78,7 @@ XGB_DLL int XGBoosterPredictFromArrayInterfaceColumns(
handle, c_json_strs, c_json_config, p_m, out_shape, out_dim, out_result);
}
// A hidden API as cache id is not being supported yet.
XGB_DLL int XGBoosterPredictFromArrayInterface(
XGB_DLL int XGBoosterPredictFromCudaArray(
BoosterHandle handle, char const *c_json_strs, char const *c_json_config,
DMatrixHandle m, xgboost::bst_ulong const **out_shape,
xgboost::bst_ulong *out_dim, const float **out_result) {

View File

@ -9,6 +9,7 @@
#include <vector>
#include "xgboost/logging.h"
#include "xgboost/json.h"
#include "xgboost/learner.h"
namespace xgboost {
@ -30,8 +31,8 @@ inline void CalcPredictShape(bool strict_shape, PredictionType type, size_t rows
std::vector<bst_ulong> *out_shape,
xgboost::bst_ulong *out_dim) {
auto &shape = *out_shape;
if ((type == PredictionType::kMargin || type == PredictionType::kValue) &&
rows != 0) {
if (type == PredictionType::kMargin && rows != 0) {
// When kValue is used, softmax can change the chunksize.
CHECK_EQ(chunksize, groups);
}
@ -110,5 +111,35 @@ inline void CalcPredictShape(bool strict_shape, PredictionType type, size_t rows
std::accumulate(shape.cbegin(), shape.cend(), 1, std::multiplies<>{}),
chunksize * rows);
}
// Reverse the ntree_limit in old prediction API.
inline uint32_t GetIterationFromTreeLimit(uint32_t ntree_limit, Learner *learner) {
// On Python and R, `best_ntree_limit` is set to `best_iteration * num_parallel_tree`.
// To reverse it we just divide it by `num_parallel_tree`.
if (ntree_limit != 0) {
learner->Configure();
uint32_t num_parallel_tree = 0;
Json config{Object()};
learner->SaveConfig(&config);
auto const &booster =
get<String const>(config["learner"]["gradient_booster"]["name"]);
if (booster == "gblinear") {
num_parallel_tree = 0;
} else if (booster == "dart") {
num_parallel_tree = std::stoi(
get<String const>(config["learner"]["gradient_booster"]["gbtree"]
["gbtree_train_param"]["num_parallel_tree"]));
} else if (booster == "gbtree") {
num_parallel_tree = std::stoi(get<String const>(
(config["learner"]["gradient_booster"]["gbtree_train_param"]
["num_parallel_tree"])));
} else {
LOG(FATAL) << "Unknown booster:" << booster;
}
ntree_limit /= std::max(num_parallel_tree, 1u);
}
return ntree_limit;
}
} // namespace xgboost
#endif // XGBOOST_C_API_C_API_UTILS_H_

View File

@ -25,6 +25,7 @@
#include "common/config.h"
#include "common/io.h"
#include "common/version.h"
#include "c_api/c_api_utils.h"
namespace xgboost {
enum CLITask {
@ -58,6 +59,8 @@ struct CLIParam : public XGBoostParameter<CLIParam> {
int dsplit;
/*!\brief limit number of trees in prediction */
int ntree_limit;
int iteration_begin;
int iteration_end;
/*!\brief whether to directly output margin value */
bool pred_margin;
/*! \brief whether dump statistics along with model */
@ -109,7 +112,11 @@ struct CLIParam : public XGBoostParameter<CLIParam> {
.add_enum("row", 2)
.describe("Data split mode.");
DMLC_DECLARE_FIELD(ntree_limit).set_default(0).set_lower_bound(0)
.describe("Number of trees used for prediction, 0 means use all trees.");
.describe("(Deprecated) Use iteration_begin/iteration_end instead.");
DMLC_DECLARE_FIELD(iteration_begin).set_default(0).set_lower_bound(0)
.describe("Begining of boosted tree iteration used for prediction.");
DMLC_DECLARE_FIELD(iteration_end).set_default(0).set_lower_bound(0)
.describe("End of boosted tree iteration used for prediction. 0 means all the trees.");
DMLC_DECLARE_FIELD(pred_margin).set_default(false)
.describe("Whether to predict margin value instead of probability.");
DMLC_DECLARE_FIELD(dump_stats).set_default(false)
@ -334,7 +341,13 @@ class CLI {
LOG(INFO) << "Start prediction...";
HostDeviceVector<bst_float> preds;
learner_->Predict(dtest, param_.pred_margin, &preds, param_.ntree_limit);
if (param_.ntree_limit != 0) {
param_.iteration_end = GetIterationFromTreeLimit(param_.ntree_limit, learner_.get());
LOG(WARNING) << "`ntree_limit` is deprecated, use `iteration_begin` and "
"`iteration_end` instead.";
}
learner_->Predict(dtest, param_.pred_margin, &preds, param_.iteration_begin,
param_.iteration_end);
LOG(CONSOLE) << "Writing prediction to " << param_.name_pred;
std::unique_ptr<dmlc::Stream> fo(

View File

@ -47,6 +47,12 @@ struct GBLinearTrainParam : public XGBoostParameter<GBLinearTrainParam> {
.describe("Maximum rows per batch.");
}
};
void LinearCheckLayer(unsigned layer_begin, unsigned layer_end) {
CHECK_EQ(layer_begin, 0) << "Linear booster does not support prediction range.";
CHECK_EQ(layer_end, 0) << "Linear booster does not support prediction range.";
}
/*!
* \brief gradient boosted linear model
*/
@ -130,20 +136,19 @@ class GBLinear : public GradientBooster {
monitor_.Stop("DoBoost");
}
void PredictBatch(DMatrix *p_fmat,
PredictionCacheEntry *predts,
bool, unsigned ntree_limit) override {
void PredictBatch(DMatrix *p_fmat, PredictionCacheEntry *predts,
bool training, unsigned layer_begin, unsigned layer_end) override {
monitor_.Start("PredictBatch");
LinearCheckLayer(layer_begin, layer_end);
auto* out_preds = &predts->predictions;
CHECK_EQ(ntree_limit, 0U)
<< "GBLinear::Predict ntrees is only valid for gbtree predictor";
this->PredictBatchInternal(p_fmat, &out_preds->HostVector());
monitor_.Stop("PredictBatch");
}
// add base margin
void PredictInstance(const SparsePage::Inst &inst,
std::vector<bst_float> *out_preds,
unsigned) override {
unsigned layer_begin, unsigned layer_end) override {
LinearCheckLayer(layer_begin, layer_end);
const int ngroup = model_.learner_model_param->num_output_group;
for (int gid = 0; gid < ngroup; ++gid) {
this->Pred(inst, dmlc::BeginPtr(*out_preds), gid,
@ -151,16 +156,15 @@ class GBLinear : public GradientBooster {
}
}
void PredictLeaf(DMatrix *, HostDeviceVector<bst_float> *, unsigned) override {
void PredictLeaf(DMatrix *, HostDeviceVector<bst_float> *, unsigned, unsigned) override {
LOG(FATAL) << "gblinear does not support prediction of leaf index";
}
void PredictContribution(DMatrix* p_fmat,
HostDeviceVector<bst_float>* out_contribs,
unsigned ntree_limit, bool, int, unsigned) override {
unsigned layer_begin, unsigned layer_end, bool, int, unsigned) override {
model_.LazyInitModel();
CHECK_EQ(ntree_limit, 0U)
<< "GBLinear::PredictContribution: ntrees is only valid for gbtree predictor";
LinearCheckLayer(layer_begin, layer_end);
const auto& base_margin = p_fmat->Info().base_margin_.ConstHostVector();
const int ngroup = model_.learner_model_param->num_output_group;
const size_t ncolumns = model_.learner_model_param->num_feature + 1;
@ -197,7 +201,8 @@ class GBLinear : public GradientBooster {
void PredictInteractionContributions(DMatrix* p_fmat,
HostDeviceVector<bst_float>* out_contribs,
unsigned, bool) override {
unsigned layer_begin, unsigned layer_end, bool) override {
LinearCheckLayer(layer_begin, layer_end);
std::vector<bst_float>& contribs = out_contribs->HostVector();
// linear models have no interaction effects

View File

@ -414,7 +414,7 @@ void GBTree::Slice(int32_t layer_begin, int32_t layer_end, int32_t step,
auto layer_trees = this->LayerTrees();
layer_end = layer_end == 0 ? model_.trees.size() / layer_trees : layer_end;
CHECK_GE(layer_end, layer_begin);
CHECK_GT(layer_end, layer_begin);
CHECK_GE(step, 1);
int32_t n_layers = (layer_end - layer_begin) / step;
std::vector<std::unique_ptr<RegTree>> &out_trees = out_model.trees;
@ -438,10 +438,35 @@ void GBTree::Slice(int32_t layer_begin, int32_t layer_end, int32_t step,
void GBTree::PredictBatch(DMatrix* p_fmat,
PredictionCacheEntry* out_preds,
bool,
unsigned ntree_limit) {
unsigned layer_begin,
unsigned layer_end) {
CHECK(configured_);
if (layer_end == 0) {
layer_end = this->BoostedRounds();
}
if (layer_begin != 0 || layer_end < out_preds->version) {
// cache is dropped.
out_preds->version = 0;
}
bool reset = false;
if (layer_begin == 0) {
layer_begin = out_preds->version;
} else {
// When begin layer is not 0, the cache is not useful.
reset = true;
}
uint32_t tree_begin, tree_end;
std::tie(tree_begin, tree_end) =
detail::LayerToTree(model_, tparam_, layer_begin, layer_end);
GetPredictor(&out_preds->predictions, p_fmat)
->PredictBatch(p_fmat, out_preds, model_, 0, ntree_limit);
->PredictBatch(p_fmat, out_preds, model_, tree_begin, tree_end);
if (reset) {
out_preds->version = 0;
} else {
uint32_t delta = layer_end - out_preds->version;
out_preds->Update(delta);
}
}
std::unique_ptr<Predictor> const &
@ -603,13 +628,14 @@ class Dart : public GBTree {
void PredictBatch(DMatrix* p_fmat,
PredictionCacheEntry* p_out_preds,
bool training,
unsigned ntree_limit) override {
unsigned layer_begin,
unsigned layer_end) override {
DropTrees(training);
int num_group = model_.learner_model_param->num_output_group;
ntree_limit *= num_group;
if (ntree_limit == 0 || ntree_limit > model_.trees.size()) {
ntree_limit = static_cast<unsigned>(model_.trees.size());
}
uint32_t tree_begin, tree_end;
std::tie(tree_begin, tree_end) =
detail::LayerToTree(model_, tparam_, layer_begin, layer_end);
size_t n = num_group * p_fmat->Info().num_row_;
const auto &base_margin = p_fmat->Info().base_margin_.ConstHostVector();
auto& out_preds = p_out_preds->predictions.HostVector();
@ -623,26 +649,24 @@ class Dart : public GBTree {
}
const int nthread = omp_get_max_threads();
InitThreadTemp(nthread);
PredLoopSpecalize(p_fmat, &out_preds, num_group, 0, ntree_limit);
PredLoopSpecalize(p_fmat, &out_preds, num_group, tree_begin, tree_end);
}
void PredictInstance(const SparsePage::Inst &inst,
std::vector<bst_float> *out_preds,
unsigned ntree_limit) override {
unsigned layer_begin, unsigned layer_end) override {
DropTrees(false);
if (thread_temp_.size() == 0) {
thread_temp_.resize(1, RegTree::FVec());
thread_temp_[0].Init(model_.learner_model_param->num_feature);
}
out_preds->resize(model_.learner_model_param->num_output_group);
ntree_limit *= model_.learner_model_param->num_output_group;
if (ntree_limit == 0 || ntree_limit > model_.trees.size()) {
ntree_limit = static_cast<unsigned>(model_.trees.size());
}
uint32_t tree_begin, tree_end;
std::tie(tree_begin, tree_end) = detail::LayerToTree(model_, tparam_, layer_begin, layer_end);
// loop over output groups
for (uint32_t gid = 0; gid < model_.learner_model_param->num_output_group; ++gid) {
(*out_preds)[gid] =
PredValue(inst, gid, &thread_temp_[0], 0, ntree_limit) +
PredValue(inst, gid, &thread_temp_[0], 0, tree_end) +
model_.learner_model_param->base_score;
}
}
@ -653,22 +677,25 @@ class Dart : public GBTree {
void PredictContribution(DMatrix* p_fmat,
HostDeviceVector<bst_float>* out_contribs,
unsigned ntree_limit, bool approximate, int,
unsigned layer_begin, unsigned layer_end, bool approximate, int,
unsigned) override {
CHECK(configured_);
uint32_t tree_begin, tree_end;
std::tie(tree_begin, tree_end) = detail::LayerToTree(model_, tparam_, layer_begin, layer_end);
cpu_predictor_->PredictContribution(p_fmat, out_contribs, model_,
ntree_limit, &weight_drop_, approximate);
tree_end, &weight_drop_, approximate);
}
void PredictInteractionContributions(DMatrix* p_fmat,
HostDeviceVector<bst_float>* out_contribs,
unsigned ntree_limit, bool approximate) override {
void PredictInteractionContributions(
DMatrix *p_fmat, HostDeviceVector<bst_float> *out_contribs,
unsigned layer_begin, unsigned layer_end, bool approximate) override {
CHECK(configured_);
uint32_t tree_begin, tree_end;
std::tie(tree_begin, tree_end) = detail::LayerToTree(model_, tparam_, layer_begin, layer_end);
cpu_predictor_->PredictInteractionContributions(p_fmat, out_contribs, model_,
ntree_limit, &weight_drop_, approximate);
tree_end, &weight_drop_, approximate);
}
protected:
inline void PredLoopSpecalize(
DMatrix* p_fmat,

View File

@ -164,7 +164,9 @@ inline std::pair<uint32_t, uint32_t> LayerToTree(gbm::GBTreeModel const &model,
if (tree_end == 0) {
tree_end = static_cast<uint32_t>(model.trees.size());
}
CHECK_LT(tree_begin, tree_end);
if (model.trees.size() != 0) {
CHECK_LE(tree_begin, tree_end);
}
return {tree_begin, tree_end};
}
@ -260,10 +262,8 @@ class GBTree : public GradientBooster {
return model_.trees.size() / this->LayerTrees();
}
void PredictBatch(DMatrix* p_fmat,
PredictionCacheEntry* out_preds,
bool training,
unsigned ntree_limit) override;
void PredictBatch(DMatrix *p_fmat, PredictionCacheEntry *out_preds,
bool training, unsigned layer_begin, unsigned layer_end) override;
void InplacePredict(dmlc::any const &x, std::shared_ptr<DMatrix> p_m,
float missing, PredictionCacheEntry *out_preds,
@ -297,33 +297,49 @@ class GBTree : public GradientBooster {
void PredictInstance(const SparsePage::Inst& inst,
std::vector<bst_float>* out_preds,
unsigned ntree_limit) override {
uint32_t layer_begin, uint32_t layer_end) override {
CHECK(configured_);
uint32_t tree_begin, tree_end;
std::tie(tree_begin, tree_end) = detail::LayerToTree(model_, tparam_, layer_begin, layer_end);
cpu_predictor_->PredictInstance(inst, out_preds, model_,
ntree_limit);
tree_end);
}
void PredictLeaf(DMatrix* p_fmat,
HostDeviceVector<bst_float>* out_preds,
unsigned ntree_limit) override {
this->GetPredictor()->PredictLeaf(p_fmat, out_preds, model_, ntree_limit);
uint32_t layer_begin, uint32_t layer_end) override {
uint32_t tree_begin, tree_end;
std::tie(tree_begin, tree_end) = detail::LayerToTree(model_, tparam_, layer_begin, layer_end);
CHECK_EQ(tree_begin, 0) << "Predict leaf supports only iteration end: (0, "
"n_iteration), use model slicing instead.";
this->GetPredictor()->PredictLeaf(p_fmat, out_preds, model_, tree_end);
}
void PredictContribution(DMatrix* p_fmat,
HostDeviceVector<bst_float>* out_contribs,
unsigned ntree_limit, bool approximate,
uint32_t layer_begin, uint32_t layer_end, bool approximate,
int, unsigned) override {
CHECK(configured_);
uint32_t tree_begin, tree_end;
std::tie(tree_begin, tree_end) = detail::LayerToTree(model_, tparam_, layer_begin, layer_end);
CHECK_EQ(tree_begin, 0)
<< "Predict contribution supports only iteration end: (0, "
"n_iteration), using model slicing instead.";
this->GetPredictor()->PredictContribution(
p_fmat, out_contribs, model_, ntree_limit, nullptr, approximate);
p_fmat, out_contribs, model_, tree_end, nullptr, approximate);
}
void PredictInteractionContributions(DMatrix* p_fmat,
HostDeviceVector<bst_float>* out_contribs,
unsigned ntree_limit, bool approximate) override {
void PredictInteractionContributions(
DMatrix *p_fmat, HostDeviceVector<bst_float> *out_contribs,
uint32_t layer_begin, uint32_t layer_end, bool approximate) override {
CHECK(configured_);
this->GetPredictor()->PredictInteractionContributions(p_fmat, out_contribs, model_,
ntree_limit, nullptr, approximate);
uint32_t tree_begin, tree_end;
std::tie(tree_begin, tree_end) = detail::LayerToTree(model_, tparam_, layer_begin, layer_end);
CHECK_EQ(tree_begin, 0)
<< "Predict interaction contribution supports only iteration end: (0, "
"n_iteration), using model slicing instead.";
this->GetPredictor()->PredictInteractionContributions(
p_fmat, out_contribs, model_, tree_end, nullptr, approximate);
}
std::vector<std::string> DumpModel(const FeatureMap& fmap,

View File

@ -22,6 +22,7 @@
#include "dmlc/any.h"
#include "xgboost/base.h"
#include "xgboost/c_api.h"
#include "xgboost/data.h"
#include "xgboost/model.h"
#include "xgboost/predictor.h"
@ -996,7 +997,7 @@ class LearnerImpl : public LearnerIO {
auto& predt = local_cache->Cache(train, generic_parameters_.gpu_id);
monitor_.Start("PredictRaw");
this->PredictRaw(train.get(), &predt, true);
this->PredictRaw(train.get(), &predt, true, 0, 0);
TrainingObserver::Instance().Observe(predt.predictions, "Predictions");
monitor_.Stop("PredictRaw");
@ -1057,7 +1058,7 @@ class LearnerImpl : public LearnerIO {
std::shared_ptr<DMatrix> m = data_sets[i];
auto &predt = local_cache->Cache(m, generic_parameters_.gpu_id);
this->ValidateDMatrix(m.get(), false);
this->PredictRaw(m.get(), &predt, false);
this->PredictRaw(m.get(), &predt, false, 0, 0);
auto &out = output_predictions_.Cache(m, generic_parameters_.gpu_id).predictions;
out.Resize(predt.predictions.Size());
@ -1075,8 +1076,8 @@ class LearnerImpl : public LearnerIO {
}
void Predict(std::shared_ptr<DMatrix> data, bool output_margin,
HostDeviceVector<bst_float>* out_preds, unsigned ntree_limit,
bool training,
HostDeviceVector<bst_float> *out_preds, unsigned layer_begin,
unsigned layer_end, bool training,
bool pred_leaf, bool pred_contribs, bool approx_contribs,
bool pred_interactions) override {
int multiple_predictions = static_cast<int>(pred_leaf) +
@ -1085,16 +1086,16 @@ class LearnerImpl : public LearnerIO {
this->Configure();
CHECK_LE(multiple_predictions, 1) << "Perform one kind of prediction at a time.";
if (pred_contribs) {
gbm_->PredictContribution(data.get(), out_preds, ntree_limit, approx_contribs);
gbm_->PredictContribution(data.get(), out_preds, layer_begin, layer_end, approx_contribs);
} else if (pred_interactions) {
gbm_->PredictInteractionContributions(data.get(), out_preds, ntree_limit,
gbm_->PredictInteractionContributions(data.get(), out_preds, layer_begin, layer_end,
approx_contribs);
} else if (pred_leaf) {
gbm_->PredictLeaf(data.get(), out_preds, ntree_limit);
gbm_->PredictLeaf(data.get(), out_preds, layer_begin, layer_end);
} else {
auto local_cache = this->GetPredictionCache();
auto& prediction = local_cache->Cache(data, generic_parameters_.gpu_id);
this->PredictRaw(data.get(), &prediction, training, ntree_limit);
this->PredictRaw(data.get(), &prediction, training, layer_begin, layer_end);
// Copy the prediction cache to output prediction. out_preds comes from C API
out_preds->SetDevice(generic_parameters_.gpu_id);
out_preds->Resize(prediction.predictions.Size());
@ -1151,12 +1152,11 @@ class LearnerImpl : public LearnerIO {
* predictor, when it equals 0, this means we are using all the trees
* \param training allow dropout when the DART booster is being used
*/
void PredictRaw(DMatrix* data, PredictionCacheEntry* out_preds,
bool training,
unsigned ntree_limit = 0) const {
void PredictRaw(DMatrix *data, PredictionCacheEntry *out_preds, bool training,
unsigned layer_begin, unsigned layer_end) const {
CHECK(gbm_ != nullptr) << "Predict must happen after Load or configuration";
this->ValidateDMatrix(data, false);
gbm_->PredictBatch(data, out_preds, training, ntree_limit);
gbm_->PredictBatch(data, out_preds, training, layer_begin, layer_end);
}
void ValidateDMatrix(DMatrix* p_fmat, bool is_training) const {

View File

@ -234,56 +234,28 @@ class CPUPredictor : public Predictor {
public:
explicit CPUPredictor(GenericParameter const* generic_param) :
Predictor::Predictor{generic_param} {}
// ntree_limit is a very problematic parameter, as it's ambiguous in the context of
// multi-output and forest. Same problem exists for tree_begin
void PredictBatch(DMatrix* dmat, PredictionCacheEntry* predts,
const gbm::GBTreeModel& model, int tree_begin,
uint32_t const ntree_limit = 0) const override {
// tree_begin is not used, right now we just enforce it to be 0.
CHECK_EQ(tree_begin, 0);
void PredictBatch(DMatrix *dmat, PredictionCacheEntry *predts,
const gbm::GBTreeModel &model, uint32_t tree_begin,
uint32_t tree_end = 0) const override {
auto* out_preds = &predts->predictions;
CHECK_GE(predts->version, tree_begin);
if (out_preds->Size() == 0 && dmat->Info().num_row_ != 0) {
CHECK_EQ(predts->version, 0);
}
// This is actually already handled in gbm, but large amount of tests rely on the
// behaviour.
if (tree_end == 0) {
tree_end = model.trees.size();
}
if (predts->version == 0) {
// out_preds->Size() can be non-zero as it's initialized here before any tree is
// built at the 0^th iterator.
this->InitOutPredictions(dmat->Info(), out_preds, model);
}
uint32_t const output_groups = model.learner_model_param->num_output_group;
CHECK_NE(output_groups, 0);
// Right now we just assume ntree_limit provided by users means number of tree layers
// in the context of multi-output model
uint32_t real_ntree_limit = ntree_limit * output_groups;
if (real_ntree_limit == 0 || real_ntree_limit > model.trees.size()) {
real_ntree_limit = static_cast<uint32_t>(model.trees.size());
if (tree_end - tree_begin == 0) {
return;
}
uint32_t const end_version = (tree_begin + real_ntree_limit) / output_groups;
// When users have provided ntree_limit, end_version can be lesser, cache is violated
if (predts->version > end_version) {
CHECK_NE(ntree_limit, 0);
this->InitOutPredictions(dmat->Info(), out_preds, model);
predts->version = 0;
}
uint32_t const beg_version = predts->version;
CHECK_LE(beg_version, end_version);
if (beg_version < end_version) {
this->PredictDMatrix(dmat, &out_preds->HostVector(), model,
beg_version * output_groups,
end_version * output_groups);
}
// delta means {size of forest} * {number of newly accumulated layers}
uint32_t delta = end_version - beg_version;
CHECK_LE(delta, model.trees.size());
predts->Update(delta);
CHECK(out_preds->Size() == output_groups * dmat->Info().num_row_ ||
out_preds->Size() == dmat->Info().num_row_);
this->PredictDMatrix(dmat, &out_preds->HostVector(), model, tree_begin,
tree_end);
}
template <typename Adapter>
@ -362,7 +334,6 @@ class CPUPredictor : public Predictor {
InitThreadTemp(nthread, model.learner_model_param->num_feature, &feat_vecs);
const MetaInfo& info = p_fmat->Info();
// number of valid trees
ntree_limit *= model.learner_model_param->num_output_group;
if (ntree_limit == 0 || ntree_limit > model.trees.size()) {
ntree_limit = static_cast<unsigned>(model.trees.size());
}
@ -398,7 +369,6 @@ class CPUPredictor : public Predictor {
InitThreadTemp(nthread, model.learner_model_param->num_feature, &feat_vecs);
const MetaInfo& info = p_fmat->Info();
// number of valid trees
ntree_limit *= model.learner_model_param->num_output_group;
if (ntree_limit == 0 || ntree_limit > model.trees.size()) {
ntree_limit = static_cast<unsigned>(model.trees.size());
}

View File

@ -536,6 +536,7 @@ class GPUPredictor : public xgboost::Predictor {
const uint32_t BLOCK_THREADS = 256;
size_t num_rows = batch.n_rows;
auto GRID_SIZE = static_cast<uint32_t>(common::DivRoundUp(num_rows, BLOCK_THREADS));
DeviceModel d_model;
bool use_shared = false;
size_t entry_start = 0;
@ -593,54 +594,27 @@ class GPUPredictor : public xgboost::Predictor {
}
void PredictBatch(DMatrix* dmat, PredictionCacheEntry* predts,
const gbm::GBTreeModel& model, int tree_begin,
unsigned ntree_limit = 0) const override {
// This function is duplicated with CPU predictor PredictBatch, see comments in there.
// FIXME(trivialfis): Remove the duplication.
const gbm::GBTreeModel& model, uint32_t tree_begin,
uint32_t tree_end = 0) const override {
int device = generic_param_->gpu_id;
CHECK_GE(device, 0) << "Set `gpu_id' to positive value for processing GPU data.";
ConfigureDevice(device);
CHECK_EQ(tree_begin, 0);
auto* out_preds = &predts->predictions;
CHECK_GE(predts->version, tree_begin);
if (out_preds->Size() == 0 && dmat->Info().num_row_ != 0) {
CHECK_EQ(predts->version, 0);
}
if (tree_end == 0) {
tree_end = model.trees.size();
}
if (predts->version == 0) {
// out_preds->Size() can be non-zero as it's initialized here before any tree is
// built at the 0^th iterator.
this->InitOutPredictions(dmat->Info(), out_preds, model);
}
uint32_t const output_groups = model.learner_model_param->num_output_group;
CHECK_NE(output_groups, 0);
uint32_t real_ntree_limit = ntree_limit * output_groups;
if (real_ntree_limit == 0 || real_ntree_limit > model.trees.size()) {
real_ntree_limit = static_cast<uint32_t>(model.trees.size());
if (tree_end - tree_begin == 0) {
return;
}
uint32_t const end_version = (tree_begin + real_ntree_limit) / output_groups;
if (predts->version > end_version) {
CHECK_NE(ntree_limit, 0);
this->InitOutPredictions(dmat->Info(), out_preds, model);
predts->version = 0;
}
uint32_t const beg_version = predts->version;
CHECK_LE(beg_version, end_version);
if (beg_version < end_version) {
this->DevicePredictInternal(dmat, out_preds, model,
beg_version * output_groups,
end_version * output_groups);
}
uint32_t delta = end_version - beg_version;
CHECK_LE(delta, model.trees.size());
predts->Update(delta);
CHECK(out_preds->Size() == output_groups * dmat->Info().num_row_ ||
out_preds->Size() == dmat->Info().num_row_);
this->DevicePredictInternal(dmat, out_preds, model, tree_begin, tree_end);
}
template <typename Adapter, typename Loader>
@ -648,15 +622,12 @@ class GPUPredictor : public xgboost::Predictor {
const gbm::GBTreeModel &model, float,
PredictionCacheEntry *out_preds,
uint32_t tree_begin, uint32_t tree_end) const {
auto max_shared_memory_bytes = dh::MaxSharedMemory(this->generic_param_->gpu_id);
uint32_t const output_groups = model.learner_model_param->num_output_group;
DeviceModel d_model;
d_model.Init(model, tree_begin, tree_end, this->generic_param_->gpu_id);
auto m = dmlc::get<std::shared_ptr<Adapter>>(x);
CHECK_EQ(m->NumColumns(), model.learner_model_param->num_feature)
<< "Number of columns in data must equal to trained model.";
CHECK_EQ(this->generic_param_->gpu_id, m->DeviceIdx())
CHECK_EQ(dh::CurrentDevice(), m->DeviceIdx())
<< "XGBoost is running on device: " << this->generic_param_->gpu_id << ", "
<< "but data is on: " << m->DeviceIdx();
if (p_m) {
@ -667,12 +638,17 @@ class GPUPredictor : public xgboost::Predictor {
info.num_row_ = m->NumRows();
this->InitOutPredictions(info, &(out_preds->predictions), model);
}
out_preds->predictions.SetDevice(m->DeviceIdx());
const uint32_t BLOCK_THREADS = 128;
auto GRID_SIZE = static_cast<uint32_t>(common::DivRoundUp(m->NumRows(), BLOCK_THREADS));
auto max_shared_memory_bytes = dh::MaxSharedMemory(m->DeviceIdx());
size_t shared_memory_bytes =
SharedMemoryBytes<BLOCK_THREADS>(m->NumColumns(), max_shared_memory_bytes);
DeviceModel d_model;
d_model.Init(model, tree_begin, tree_end, m->DeviceIdx());
bool use_shared = shared_memory_bytes != 0;
size_t entry_start = 0;
@ -707,20 +683,17 @@ class GPUPredictor : public xgboost::Predictor {
void PredictContribution(DMatrix* p_fmat,
HostDeviceVector<bst_float>* out_contribs,
const gbm::GBTreeModel& model, unsigned ntree_limit,
const gbm::GBTreeModel& model, unsigned tree_end,
std::vector<bst_float>*,
bool approximate, int,
unsigned) const override {
if (approximate) {
LOG(FATAL) << "Approximated contribution is not implemented in GPU Predictor.";
}
dh::safe_cuda(cudaSetDevice(generic_param_->gpu_id));
out_contribs->SetDevice(generic_param_->gpu_id);
uint32_t real_ntree_limit =
ntree_limit * model.learner_model_param->num_output_group;
if (real_ntree_limit == 0 || real_ntree_limit > model.trees.size()) {
real_ntree_limit = static_cast<uint32_t>(model.trees.size());
if (tree_end == 0 || tree_end > model.trees.size()) {
tree_end = static_cast<uint32_t>(model.trees.size());
}
const int ngroup = model.learner_model_param->num_output_group;
@ -734,8 +707,7 @@ class GPUPredictor : public xgboost::Predictor {
auto phis = out_contribs->DeviceSpan();
dh::device_vector<gpu_treeshap::PathElement> device_paths;
ExtractPaths(&device_paths, model, real_ntree_limit,
generic_param_->gpu_id);
ExtractPaths(&device_paths, model, tree_end, generic_param_->gpu_id);
for (auto& batch : p_fmat->GetBatches<SparsePage>()) {
batch.data.SetDevice(generic_param_->gpu_id);
batch.offset.SetDevice(generic_param_->gpu_id);
@ -761,20 +733,17 @@ class GPUPredictor : public xgboost::Predictor {
void PredictInteractionContributions(DMatrix* p_fmat,
HostDeviceVector<bst_float>* out_contribs,
const gbm::GBTreeModel& model,
unsigned ntree_limit,
unsigned tree_end,
std::vector<bst_float>*,
bool approximate) const override {
if (approximate) {
LOG(FATAL) << "[Internal error]: " << __func__
<< " approximate is not implemented in GPU Predictor.";
}
dh::safe_cuda(cudaSetDevice(generic_param_->gpu_id));
out_contribs->SetDevice(generic_param_->gpu_id);
uint32_t real_ntree_limit =
ntree_limit * model.learner_model_param->num_output_group;
if (real_ntree_limit == 0 || real_ntree_limit > model.trees.size()) {
real_ntree_limit = static_cast<uint32_t>(model.trees.size());
if (tree_end == 0 || tree_end > model.trees.size()) {
tree_end = static_cast<uint32_t>(model.trees.size());
}
const int ngroup = model.learner_model_param->num_output_group;
@ -789,8 +758,7 @@ class GPUPredictor : public xgboost::Predictor {
auto phis = out_contribs->DeviceSpan();
dh::device_vector<gpu_treeshap::PathElement> device_paths;
ExtractPaths(&device_paths, model, real_ntree_limit,
generic_param_->gpu_id);
ExtractPaths(&device_paths, model, tree_end, generic_param_->gpu_id);
for (auto& batch : p_fmat->GetBatches<SparsePage>()) {
batch.data.SetDevice(generic_param_->gpu_id);
batch.offset.SetDevice(generic_param_->gpu_id);
@ -841,29 +809,28 @@ class GPUPredictor : public xgboost::Predictor {
<< " is not implemented in GPU Predictor.";
}
void PredictLeaf(DMatrix* p_fmat, HostDeviceVector<bst_float>* predictions,
const gbm::GBTreeModel& model,
unsigned ntree_limit) const override {
void PredictLeaf(DMatrix *p_fmat, HostDeviceVector<bst_float> *predictions,
const gbm::GBTreeModel &model,
unsigned tree_end) const override {
dh::safe_cuda(cudaSetDevice(generic_param_->gpu_id));
auto max_shared_memory_bytes = ConfigureDevice(generic_param_->gpu_id);
const MetaInfo& info = p_fmat->Info();
constexpr uint32_t kBlockThreads = 128;
size_t shared_memory_bytes =
SharedMemoryBytes<kBlockThreads>(info.num_col_, max_shared_memory_bytes);
size_t shared_memory_bytes = SharedMemoryBytes<kBlockThreads>(
info.num_col_, max_shared_memory_bytes);
bool use_shared = shared_memory_bytes != 0;
bst_feature_t num_features = info.num_col_;
bst_row_t num_rows = info.num_row_;
size_t entry_start = 0;
uint32_t real_ntree_limit = ntree_limit * model.learner_model_param->num_output_group;
if (real_ntree_limit == 0 || real_ntree_limit > model.trees.size()) {
real_ntree_limit = static_cast<uint32_t>(model.trees.size());
if (tree_end == 0 || tree_end > model.trees.size()) {
tree_end = static_cast<uint32_t>(model.trees.size());
}
predictions->SetDevice(generic_param_->gpu_id);
predictions->Resize(num_rows * real_ntree_limit);
predictions->Resize(num_rows * tree_end);
DeviceModel d_model;
d_model.Init(model, 0, real_ntree_limit, this->generic_param_->gpu_id);
d_model.Init(model, 0, tree_end, this->generic_param_->gpu_id);
if (p_fmat->PageExists<SparsePage>()) {
for (auto const& batch : p_fmat->GetBatches<SparsePage>()) {

View File

@ -34,6 +34,7 @@ dependencies:
- llvmlite
- pip:
- shap
- ipython # required by shap at import time.
- guzzle_sphinx_theme
- datatable
- modin[all]

View File

@ -51,6 +51,53 @@ TEST(GBTree, SelectTreeMethod) {
#endif // XGBOOST_USE_CUDA
}
TEST(GBTree, PredictionCache) {
size_t constexpr kRows = 100, kCols = 10;
GenericParameter generic_param;
generic_param.UpdateAllowUnknown(Args{});
LearnerModelParam mparam;
mparam.base_score = 0.5;
mparam.num_feature = kCols;
mparam.num_output_group = 1;
std::unique_ptr<GradientBooster> p_gbm {
GradientBooster::Create("gbtree", &generic_param, &mparam)};
auto& gbtree = dynamic_cast<gbm::GBTree&> (*p_gbm);
gbtree.Configure({{"tree_method", "hist"}});
auto p_m = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix();
auto gpair = GenerateRandomGradients(kRows);
PredictionCacheEntry out_predictions;
gbtree.DoBoost(p_m.get(), &gpair, &out_predictions);
gbtree.PredictBatch(p_m.get(), &out_predictions, false, 0, 0);
ASSERT_EQ(1, out_predictions.version);
std::vector<float> first_iter = out_predictions.predictions.HostVector();
// Add 1 more boosted round
gbtree.DoBoost(p_m.get(), &gpair, &out_predictions);
gbtree.PredictBatch(p_m.get(), &out_predictions, false, 0, 0);
ASSERT_EQ(2, out_predictions.version);
// Update the cache for all rounds
out_predictions.version = 0;
gbtree.PredictBatch(p_m.get(), &out_predictions, false, 0, 0);
ASSERT_EQ(2, out_predictions.version);
gbtree.DoBoost(p_m.get(), &gpair, &out_predictions);
// drop the cache.
gbtree.PredictBatch(p_m.get(), &out_predictions, false, 1, 2);
ASSERT_EQ(0, out_predictions.version);
// half open set [1, 3)
gbtree.PredictBatch(p_m.get(), &out_predictions, false, 1, 3);
ASSERT_EQ(0, out_predictions.version);
// iteration end
gbtree.PredictBatch(p_m.get(), &out_predictions, false, 0, 2);
ASSERT_EQ(2, out_predictions.version);
// restart the cache when end iteration is smaller than cache version
gbtree.PredictBatch(p_m.get(), &out_predictions, false, 0, 1);
ASSERT_EQ(1, out_predictions.version);
ASSERT_EQ(out_predictions.predictions.HostVector(), first_iter);
}
TEST(GBTree, WrongUpdater) {
size_t constexpr kRows = 17;
size_t constexpr kCols = 15;

View File

@ -32,7 +32,7 @@ TEST(CpuPredictor, Basic) {
// Test predict batch
PredictionCacheEntry out_predictions;
cpu_predictor->PredictBatch(dmat.get(), &out_predictions, model, 0);
ASSERT_EQ(model.trees.size(), out_predictions.version);
std::vector<float>& out_predictions_h = out_predictions.predictions.HostVector();
for (size_t i = 0; i < out_predictions.predictions.Size(); i++) {
ASSERT_EQ(out_predictions_h[i], 1.5);
@ -215,7 +215,7 @@ TEST(CpuPredictor, UpdatePredictionCache) {
PredictionCacheEntry out_predictions;
// perform fair prediction on the same input data, should be equal to cached result
gbm->PredictBatch(dmat.get(), &out_predictions, false, 0);
gbm->PredictBatch(dmat.get(), &out_predictions, false, 0, 0);
std::vector<float> &out_predictions_h = out_predictions.predictions.HostVector();
std::vector<float> &predtion_cache_from_train = predtion_cache.predictions.HostVector();

View File

@ -45,7 +45,6 @@ TEST(GPUPredictor, Basic) {
PredictionCacheEntry cpu_out_predictions;
gpu_predictor->PredictBatch(dmat.get(), &gpu_out_predictions, model, 0);
ASSERT_EQ(model.trees.size(), gpu_out_predictions.version);
cpu_predictor->PredictBatch(dmat.get(), &cpu_out_predictions, model, 0);
std::vector<float>& gpu_out_predictions_h = gpu_out_predictions.predictions.HostVector();

View File

@ -64,10 +64,10 @@ void TestTrainingPrediction(size_t rows, size_t bins,
}
HostDeviceVector<float> from_full;
learner->Predict(p_full, false, &from_full);
learner->Predict(p_full, false, &from_full, 0, 0);
HostDeviceVector<float> from_hist;
learner->Predict(p_hist, false, &from_hist);
learner->Predict(p_hist, false, &from_hist, 0, 0);
for (size_t i = 0; i < rows; ++i) {
EXPECT_NEAR(from_hist.ConstHostVector()[i],
@ -157,20 +157,20 @@ void TestPredictionWithLesserFeatures(std::string predictor_name) {
learner->SaveConfig(&config);
ASSERT_EQ(get<String>(config["learner"]["gradient_booster"]["gbtree_train_param"]["predictor"]), predictor_name);
learner->Predict(m_test, false, &prediction);
learner->Predict(m_test, false, &prediction, 0, 0);
ASSERT_EQ(prediction.Size(), kRows);
auto m_invalid = RandomDataGenerator(kRows, kTrainCols + 1, 0.5).GenerateDMatrix(false);
ASSERT_THROW({learner->Predict(m_invalid, false, &prediction);}, dmlc::Error);
ASSERT_THROW({learner->Predict(m_invalid, false, &prediction, 0, 0);}, dmlc::Error);
#if defined(XGBOOST_USE_CUDA)
HostDeviceVector<float> from_cpu;
learner->SetParam("predictor", "cpu_predictor");
learner->Predict(m_test, false, &from_cpu);
learner->Predict(m_test, false, &from_cpu, 0, 0);
HostDeviceVector<float> from_cuda;
learner->SetParam("predictor", "gpu_predictor");
learner->Predict(m_test, false, &from_cuda);
learner->Predict(m_test, false, &from_cuda, 0, 0);
auto const& h_cpu = from_cpu.ConstHostVector();
auto const& h_gpu = from_cuda.ConstHostVector();

View File

@ -221,9 +221,10 @@ TEST(Learner, MultiThreadedPredict) {
auto &entry = learner->GetThreadLocal().prediction_entry;
HostDeviceVector<float> predictions;
for (size_t iter = 0; iter < kIters; ++iter) {
learner->Predict(p_data, false, &entry.predictions);
learner->Predict(p_data, false, &predictions, 0, true); // leaf
learner->Predict(p_data, false, &predictions, 0, false, true); // contribs
learner->Predict(p_data, false, &entry.predictions, 0, 0);
learner->Predict(p_data, false, &predictions, 0, 0, false, true); // leaf
learner->Predict(p_data, false, &predictions, 0, 0, false, false, true); // contribs
}
});
}

View File

@ -112,17 +112,24 @@ def _test_cupy_metainfo(DMatrixT):
@pytest.mark.skipif(**tm.no_sklearn())
def test_cupy_training_with_sklearn():
import cupy as cp
np.random.seed(1)
cp.random.seed(1)
X = cp.random.randn(50, 10, dtype='float32')
y = (cp.random.randn(50, dtype='float32') > 0).astype('int8')
X = cp.random.randn(50, 10, dtype="float32")
y = (cp.random.randn(50, dtype="float32") > 0).astype("int8")
weights = np.random.random(50) + 1
cupy_weights = cp.array(weights)
base_margin = np.random.random(50)
cupy_base_margin = cp.array(base_margin)
clf = xgb.XGBClassifier(gpu_id=0, tree_method='gpu_hist', use_label_encoder=False)
clf.fit(X, y, sample_weight=cupy_weights, base_margin=cupy_base_margin, eval_set=[(X, y)])
clf = xgb.XGBClassifier(gpu_id=0, tree_method="gpu_hist", use_label_encoder=False)
clf.fit(
X,
y,
sample_weight=cupy_weights,
base_margin=cupy_base_margin,
eval_set=[(X, y)],
)
pred = clf.predict(X)
assert np.array_equal(np.unique(pred), np.array([0, 1]))

View File

@ -16,13 +16,15 @@ if sys.platform.startswith("win"):
pytest.skip("Skipping dask tests on Windows", allow_module_level=True)
sys.path.append("tests/python")
from test_with_dask import run_empty_dmatrix_reg # noqa
from test_with_dask import run_empty_dmatrix_cls # noqa
from test_with_dask import _get_client_workers # noqa
from test_with_dask import generate_array # noqa
from test_with_dask import kCols as random_cols # noqa
from test_with_dask import suppress # noqa
import testing as tm # noqa
from test_with_dask import run_empty_dmatrix_reg # noqa
from test_with_dask import run_boost_from_prediction # noqa
from test_with_dask import run_dask_classifier # noqa
from test_with_dask import run_empty_dmatrix_cls # noqa
from test_with_dask import _get_client_workers # noqa
from test_with_dask import generate_array # noqa
from test_with_dask import kCols as random_cols # noqa
from test_with_dask import suppress # noqa
import testing as tm # noqa
try:
@ -132,9 +134,9 @@ def run_gpu_hist(
num_rounds: int,
dataset: tm.TestDataset,
DMatrixT: Type,
client: Client
client: Client,
) -> None:
params['tree_method'] = 'gpu_hist'
params["tree_method"] = "gpu_hist"
params = dataset.set_params(params)
# It doesn't make sense to distribute a completely
# empty dataset.
@ -143,26 +145,40 @@ def run_gpu_hist(
chunk = 128
X = to_cp(dataset.X, DMatrixT)
X = da.from_array(X,
chunks=(chunk, dataset.X.shape[1]))
X = da.from_array(X, chunks=(chunk, dataset.X.shape[1]))
y = to_cp(dataset.y, DMatrixT)
y = da.from_array(y, chunks=(chunk, ))
y = da.from_array(y, chunks=(chunk,))
if dataset.w is not None:
w = to_cp(dataset.w, DMatrixT)
w = da.from_array(w, chunks=(chunk, ))
w = da.from_array(w, chunks=(chunk,))
else:
w = None
if DMatrixT is dxgb.DaskDeviceQuantileDMatrix:
m = DMatrixT(client, data=X, label=y, weight=w,
max_bin=params.get('max_bin', 256))
m = DMatrixT(
client, data=X, label=y, weight=w, max_bin=params.get("max_bin", 256)
)
else:
m = DMatrixT(client, data=X, label=y, weight=w)
history = dxgb.train(client, params=params, dtrain=m,
num_boost_round=num_rounds,
evals=[(m, 'train')])['history']
history = dxgb.train(
client,
params=params,
dtrain=m,
num_boost_round=num_rounds,
evals=[(m, "train")],
)["history"]
note(history)
assert tm.non_increasing(history['train'][dataset.metric])
assert tm.non_increasing(history["train"][dataset.metric])
def test_boost_from_prediction(local_cuda_cluster: LocalCUDACluster) -> None:
import cudf
from sklearn.datasets import load_breast_cancer
with Client(local_cuda_cluster) as client:
X_, y_ = load_breast_cancer(return_X_y=True)
X = dd.from_array(X_, chunksize=100).map_partitions(cudf.from_pandas)
y = dd.from_array(y_, chunksize=100).map_partitions(cudf.from_pandas)
run_boost_from_prediction(X, y, "gpu_hist", client)
class TestDistributedGPU:
@ -246,6 +262,20 @@ class TestDistributedGPU:
dump = booster.get_dump(dump_format='json')
assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
@pytest.mark.skipif(**tm.no_cudf())
@pytest.mark.skipif(**tm.no_dask())
@pytest.mark.skipif(**tm.no_dask_cuda())
@pytest.mark.parametrize("model", ["boosting"])
def test_dask_classifier(self, model, local_cuda_cluster: LocalCUDACluster) -> None:
import dask_cudf
with Client(local_cuda_cluster) as client:
X_, y_, w_ = generate_array(with_weights=True)
y_ = (y_ * 10).astype(np.int32)
X = dask_cudf.from_dask_dataframe(dd.from_dask_array(X_))
y = dask_cudf.from_dask_dataframe(dd.from_dask_array(y_))
w = dask_cudf.from_dask_dataframe(dd.from_dask_array(w_))
run_dask_classifier(X, y, w, model, client)
@pytest.mark.skipif(**tm.no_dask())
@pytest.mark.skipif(**tm.no_dask_cuda())
@pytest.mark.mgpu

View File

@ -434,7 +434,13 @@ class TestModels:
booster[...:end] = booster
sliced_0 = booster[1:3]
np.testing.assert_allclose(
booster.predict(dtrain, iteration_range=(1, 3)), sliced_0.predict(dtrain)
)
sliced_1 = booster[3:7]
np.testing.assert_allclose(
booster.predict(dtrain, iteration_range=(3, 7)), sliced_1.predict(dtrain)
)
predt_0 = sliced_0.predict(dtrain, output_margin=True)
predt_1 = sliced_1.predict(dtrain, output_margin=True)

View File

@ -47,30 +47,27 @@ def run_predict_leaf(predictor):
empty_leaf = booster.predict(empty, pred_leaf=True)
assert empty_leaf.shape[0] == 0
leaf = booster.predict(m, pred_leaf=True)
leaf = booster.predict(m, pred_leaf=True, strict_shape=True)
assert leaf.shape[0] == rows
assert leaf.shape[1] == classes * num_parallel_tree * num_boost_round
assert leaf.shape[1] == num_boost_round
assert leaf.shape[2] == classes
assert leaf.shape[3] == num_parallel_tree
for i in range(rows):
row = leaf[i, ...]
for j in range(num_boost_round):
start = classes * num_parallel_tree * j
end = classes * num_parallel_tree * (j + 1)
layer = row[start: end]
for c in range(classes):
tree_group = layer[c * num_parallel_tree: (c + 1) * num_parallel_tree]
for k in range(classes):
tree_group = leaf[i, j, k, :]
assert tree_group.shape[0] == num_parallel_tree
# no subsampling so tree in same forest should output same
# leaf.
# No sampling, all trees within forest are the same
assert np.all(tree_group == tree_group[0])
ntree_limit = 2
sliced = booster.predict(
m, pred_leaf=True, ntree_limit=num_parallel_tree * ntree_limit
m, pred_leaf=True, ntree_limit=num_parallel_tree * ntree_limit, strict_shape=True
)
first = sliced[0, ...]
assert first.shape[0] == classes * num_parallel_tree * ntree_limit
assert np.prod(first.shape) == classes * num_parallel_tree * ntree_limit
return leaf
@ -78,6 +75,23 @@ def test_predict_leaf():
run_predict_leaf('cpu_predictor')
def test_predict_shape():
from sklearn.datasets import load_boston
X, y = load_boston(return_X_y=True)
reg = xgb.XGBRegressor(n_estimators=1)
reg.fit(X, y)
predt = reg.get_booster().predict(xgb.DMatrix(X), strict_shape=True)
assert len(predt.shape) == 2
assert predt.shape[0] == X.shape[0]
assert predt.shape[1] == 1
contrib = reg.get_booster().predict(
xgb.DMatrix(X), pred_contribs=True, strict_shape=True
)
assert len(contrib.shape) == 3
assert contrib.shape[1] == 1
class TestInplacePredict:
'''Tests for running inplace prediction'''
@classmethod
@ -92,8 +106,7 @@ class TestInplacePredict:
dtrain = xgb.DMatrix(cls.X, cls.y)
cls.booster = xgb.train({'tree_method': 'hist'},
dtrain, num_boost_round=10)
cls.booster = xgb.train({'tree_method': 'hist'}, dtrain, num_boost_round=10)
cls.test = xgb.DMatrix(cls.X[:10, ...])

View File

@ -159,12 +159,9 @@ def test_dask_predict_shape_infer(client: "Client") -> None:
assert prediction.shape[1] == 3
@pytest.mark.parametrize("tree_method", ["hist", "approx"])
def test_boost_from_prediction(tree_method: str, client: "Client") -> None:
from sklearn.datasets import load_breast_cancer
X_, y_ = load_breast_cancer(return_X_y=True)
X, y = dd.from_array(X_, chunksize=100), dd.from_array(y_, chunksize=100)
def run_boost_from_prediction(
X: xgb.dask._DaskCollection, y: xgb.dask._DaskCollection, tree_method: str, client: "Client"
) -> None:
model_0 = xgb.dask.DaskXGBClassifier(
learning_rate=0.3, random_state=0, n_estimators=4,
tree_method=tree_method)
@ -202,6 +199,30 @@ def test_boost_from_prediction(tree_method: str, client: "Client") -> None:
assert margined_res[i] < unmargined_res[i]
@pytest.mark.parametrize("tree_method", ["hist", "approx"])
def test_boost_from_prediction(tree_method: str, client: "Client") -> None:
from sklearn.datasets import load_breast_cancer
X_, y_ = load_breast_cancer(return_X_y=True)
X, y = dd.from_array(X_, chunksize=100), dd.from_array(y_, chunksize=100)
run_boost_from_prediction(X, y, tree_method, client)
def test_inplace_predict(client: "Client") -> None:
from sklearn.datasets import load_boston
X_, y_ = load_boston(return_X_y=True)
X, y = dd.from_array(X_, chunksize=32), dd.from_array(y_, chunksize=32)
reg = xgb.dask.DaskXGBRegressor(n_estimators=4).fit(X, y)
booster = reg.get_booster()
base_margin = y
inplace = xgb.dask.inplace_predict(
client, booster, X, base_margin=base_margin
).compute()
Xy = xgb.dask.DaskDMatrix(client, X, base_margin=base_margin)
copied = xgb.dask.predict(client, booster, Xy).compute()
np.testing.assert_allclose(inplace, copied)
def test_dask_missing_value_reg(client: "Client") -> None:
X_0 = np.ones((20 // 2, kCols))
X_1 = np.zeros((20 // 2, kCols))
@ -288,10 +309,13 @@ def test_dask_regressor(model: str, client: "Client") -> None:
assert forest == 2
@pytest.mark.parametrize("model", ["boosting", "rf"])
def test_dask_classifier(model: str, client: "Client") -> None:
X, y, w = generate_array(with_weights=True)
y = (y * 10).astype(np.int32)
def run_dask_classifier(
X: xgb.dask._DaskCollection,
y: xgb.dask._DaskCollection,
w: xgb.dask._DaskCollection,
model: str,
client: "Client",
) -> None:
if model == "boosting":
classifier = xgb.dask.DaskXGBClassifier(
verbosity=1, n_estimators=2, eval_metric="merror"
@ -306,14 +330,13 @@ def test_dask_classifier(model: str, client: "Client") -> None:
classifier.client = client
classifier.fit(X, y, sample_weight=w, eval_set=[(X, y)])
prediction = classifier.predict(X)
prediction = classifier.predict(X).compute()
assert prediction.ndim == 1
assert prediction.shape[0] == kRows
history = classifier.evals_result()
assert isinstance(prediction, da.Array)
assert isinstance(history, dict)
assert list(history.keys())[0] == "validation_0"
@ -332,7 +355,7 @@ def test_dask_classifier(model: str, client: "Client") -> None:
assert forest == 2
# Test .predict_proba()
probas = classifier.predict_proba(X)
probas = classifier.predict_proba(X).compute()
assert classifier.n_classes_ == 10
assert probas.ndim == 2
assert probas.shape[0] == kRows
@ -341,18 +364,33 @@ def test_dask_classifier(model: str, client: "Client") -> None:
cls_booster = classifier.get_booster()
single_node_proba = cls_booster.inplace_predict(X.compute())
np.testing.assert_allclose(single_node_proba, probas.compute())
# test shared by CPU and GPU
if isinstance(single_node_proba, np.ndarray):
np.testing.assert_allclose(single_node_proba, probas)
else:
import cupy
cupy.testing.assert_allclose(single_node_proba, probas)
# Test with dataframe.
X_d = dd.from_dask_array(X)
y_d = dd.from_dask_array(y)
classifier.fit(X_d, y_d)
# Test with dataframe, not shared with GPU as cupy doesn't work well with da.unique.
if isinstance(X, da.Array):
X_d: dd.DataFrame = X.to_dask_dataframe()
assert classifier.n_classes_ == 10
prediction = classifier.predict(X_d).compute()
assert classifier.n_classes_ == 10
prediction_df = classifier.predict(X_d).compute()
assert prediction.ndim == 1
assert prediction.shape[0] == kRows
assert prediction_df.ndim == 1
assert prediction_df.shape[0] == kRows
np.testing.assert_allclose(prediction_df, prediction)
probas = classifier.predict_proba(X).compute()
np.testing.assert_allclose(single_node_proba, probas)
@pytest.mark.parametrize("model", ["boosting", "rf"])
def test_dask_classifier(model: str, client: "Client") -> None:
X, y, w = generate_array(with_weights=True)
y = (y * 10).astype(np.int32)
run_dask_classifier(X, y, w, model, client)
@pytest.mark.skipif(**tm.no_sklearn())
@ -913,9 +951,9 @@ class TestWithDask:
train = xgb.dask.DaskDMatrix(client, dX, dy)
dX = dd.from_array(X)
dX = client.persist(dX, workers={dX: workers[1]})
dX = client.persist(dX, workers=workers[1])
dy = dd.from_array(y)
dy = client.persist(dy, workers={dy: workers[1]})
dy = client.persist(dy, workers=workers[1])
valid = xgb.dask.DaskDMatrix(client, dX, dy)
merged = xgb.dask._get_workers_from_data(train, evals=[(valid, 'Valid')])
@ -1060,6 +1098,16 @@ class TestWithDask:
assert_shape(shap.shape)
assert np.allclose(np.sum(shap, axis=len(shap.shape) - 1), margin, 1e-5, 1e-5)
X = dd.from_dask_array(X).repartition(npartitions=32)
y = dd.from_dask_array(y).repartition(npartitions=32)
shap_df = xgb.dask.predict(
client, booster, X, pred_contribs=True, validate_features=False
).compute()
assert_shape(shap_df.shape)
assert np.allclose(
np.sum(shap_df, axis=len(shap_df.shape) - 1), margin, 1e-5, 1e-5
)
def run_shap_cls_sklearn(self, X: Any, y: Any, client: "Client") -> None:
X, y = da.from_array(X, chunks=(32, -1)), da.from_array(y, chunks=32)
cls = xgb.dask.DaskXGBClassifier(n_estimators=4)