Simplify inplace-predict. (#7910)
Pass the `X` as part of Proxy DMatrix instead of an independent `dmlc::any`.
This commit is contained in:
@@ -111,15 +111,14 @@ class GradientBooster : public Model, public Configurable {
|
||||
/*!
|
||||
* \brief Inplace prediction.
|
||||
*
|
||||
* \param x A type erased data adapter.
|
||||
* \param p_fmat A proxy DMatrix that contains the data and related
|
||||
* meta info.
|
||||
* \param missing Missing value in the data.
|
||||
* \param [in,out] out_preds The output preds.
|
||||
* \param layer_begin (Optional) Beginning of boosted tree layer used for prediction.
|
||||
* \param layer_end (Optional) End of booster layer. 0 means do not limit trees.
|
||||
*/
|
||||
virtual void InplacePredict(dmlc::any const &, std::shared_ptr<DMatrix>, float,
|
||||
PredictionCacheEntry*,
|
||||
uint32_t,
|
||||
virtual void InplacePredict(std::shared_ptr<DMatrix>, float, PredictionCacheEntry*, uint32_t,
|
||||
uint32_t) const {
|
||||
LOG(FATAL) << "Inplace predict is not supported by current booster.";
|
||||
}
|
||||
|
||||
@@ -139,21 +139,16 @@ class Learner : public Model, public Configurable, public dmlc::Serializable {
|
||||
/*!
|
||||
* \brief Inplace prediction.
|
||||
*
|
||||
* \param x A type erased data adapter.
|
||||
* \param p_m An optional Proxy DMatrix object storing meta info like
|
||||
* base margin. Can be nullptr.
|
||||
* \param p_fmat A proxy DMatrix that contains the data and related meta info.
|
||||
* \param type Prediction type.
|
||||
* \param missing Missing value in the data.
|
||||
* \param [in,out] out_preds Pointer to output prediction vector.
|
||||
* \param layer_begin Beginning of boosted tree layer used for prediction.
|
||||
* \param layer_end End of booster layer. 0 means do not limit trees.
|
||||
*/
|
||||
virtual void InplacePredict(dmlc::any const &x,
|
||||
std::shared_ptr<DMatrix> p_m,
|
||||
PredictionType type,
|
||||
float missing,
|
||||
HostDeviceVector<bst_float> **out_preds,
|
||||
uint32_t layer_begin, uint32_t layer_end) = 0;
|
||||
virtual void InplacePredict(std::shared_ptr<DMatrix> p_m, PredictionType type, float missing,
|
||||
HostDeviceVector<bst_float>** out_preds, uint32_t layer_begin,
|
||||
uint32_t layer_end) = 0;
|
||||
|
||||
/*!
|
||||
* \brief Calculate feature score. See doc in C API for outputs.
|
||||
|
||||
@@ -145,7 +145,9 @@ class Predictor {
|
||||
|
||||
/**
|
||||
* \brief Inplace prediction.
|
||||
* \param x Type erased data adapter.
|
||||
*
|
||||
* \param p_fmat A proxy DMatrix that contains the data and related
|
||||
* meta info.
|
||||
* \param model The model to predict from.
|
||||
* \param missing Missing value in the data.
|
||||
* \param [in,out] out_preds The output preds.
|
||||
@@ -154,11 +156,9 @@ class Predictor {
|
||||
*
|
||||
* \return True if the data can be handled by current predictor, false otherwise.
|
||||
*/
|
||||
virtual bool InplacePredict(dmlc::any const &x, std::shared_ptr<DMatrix> p_m,
|
||||
const gbm::GBTreeModel &model, float missing,
|
||||
PredictionCacheEntry *out_preds,
|
||||
uint32_t tree_begin = 0,
|
||||
uint32_t tree_end = 0) const = 0;
|
||||
virtual bool InplacePredict(std::shared_ptr<DMatrix> p_fmat, const gbm::GBTreeModel& model,
|
||||
float missing, PredictionCacheEntry* out_preds,
|
||||
uint32_t tree_begin = 0, uint32_t tree_end = 0) const = 0;
|
||||
/**
|
||||
* \brief online prediction function, predict score for one instance at a time
|
||||
* NOTE: use the batch prediction interface if possible, batch prediction is
|
||||
|
||||
Reference in New Issue
Block a user