Enhance inplace prediction. (#6653)
* Accept array interface for csr and array. * Accept an optional proxy dmatrix for metainfo. This constructs an explicit `_ProxyDMatrix` type in Python. * Remove unused doc. * Add strict output.
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2014-2020 by Contributors
|
||||
* Copyright 2014-2021 by Contributors
|
||||
* \file gbm.h
|
||||
* \brief Interface of gradient booster,
|
||||
* that learns through gradient statistics.
|
||||
@@ -118,7 +118,7 @@ class GradientBooster : public Model, public Configurable {
|
||||
* \param layer_begin (Optional) Begining of boosted tree layer used for prediction.
|
||||
* \param layer_end (Optional) End of booster layer. 0 means do not limit trees.
|
||||
*/
|
||||
virtual void InplacePredict(dmlc::any const &, float,
|
||||
virtual void InplacePredict(dmlc::any const &, std::shared_ptr<DMatrix>, float,
|
||||
PredictionCacheEntry*,
|
||||
uint32_t,
|
||||
uint32_t) const {
|
||||
|
||||
@@ -308,6 +308,7 @@ struct StringView {
|
||||
public:
|
||||
StringView() = default;
|
||||
StringView(CharT const* str, size_t size) : str_{str}, size_{size} {}
|
||||
explicit StringView(std::string const& str): str_{str.c_str()}, size_{str.size()} {}
|
||||
explicit StringView(CharT const* str) : str_{str}, size_{Traits::length(str)} {}
|
||||
|
||||
CharT const& operator[](size_t p) const { return str_[p]; }
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2015-2020 by Contributors
|
||||
* Copyright 2015-2021 by Contributors
|
||||
* \file learner.h
|
||||
* \brief Learner interface that integrates objective, gbm and evaluation together.
|
||||
* This is the user facing XGBoost training module.
|
||||
@@ -30,6 +30,15 @@ class ObjFunction;
|
||||
class DMatrix;
|
||||
class Json;
|
||||
|
||||
enum class PredictionType : std::uint8_t { // NOLINT
|
||||
kValue = 0,
|
||||
kMargin = 1,
|
||||
kContribution = 2,
|
||||
kApproxContribution = 3,
|
||||
kInteraction = 4,
|
||||
kLeaf = 5
|
||||
};
|
||||
|
||||
/*! \brief entry to to easily hold returning information */
|
||||
struct XGBAPIThreadLocalEntry {
|
||||
/*! \brief result holder for returning string */
|
||||
@@ -42,7 +51,10 @@ struct XGBAPIThreadLocalEntry {
|
||||
std::vector<bst_float> ret_vec_float;
|
||||
/*! \brief temp variable of gradient pairs. */
|
||||
std::vector<GradientPair> tmp_gpair;
|
||||
/*! \brief Temp variable for returing prediction result. */
|
||||
PredictionCacheEntry prediction_entry;
|
||||
/*! \brief Temp variable for returing prediction shape. */
|
||||
std::vector<bst_ulong> prediction_shape;
|
||||
};
|
||||
|
||||
/*!
|
||||
@@ -123,13 +135,17 @@ class Learner : public Model, public Configurable, public dmlc::Serializable {
|
||||
* \brief Inplace prediction.
|
||||
*
|
||||
* \param x A type erased data adapter.
|
||||
* \param p_m An optional Proxy DMatrix object storing meta info like
|
||||
* base margin. Can be nullptr.
|
||||
* \param type Prediction type.
|
||||
* \param missing Missing value in the data.
|
||||
* \param [in,out] out_preds Pointer to output prediction vector.
|
||||
* \param layer_begin (Optional) Begining of boosted tree layer used for prediction.
|
||||
* \param layer_end (Optional) End of booster layer. 0 means do not limit trees.
|
||||
* \param layer_begin Begining of boosted tree layer used for prediction.
|
||||
* \param layer_end End of booster layer. 0 means do not limit trees.
|
||||
*/
|
||||
virtual void InplacePredict(dmlc::any const& x, std::string const& type,
|
||||
virtual void InplacePredict(dmlc::any const &x,
|
||||
std::shared_ptr<DMatrix> p_m,
|
||||
PredictionType type,
|
||||
float missing,
|
||||
HostDeviceVector<bst_float> **out_preds,
|
||||
uint32_t layer_begin, uint32_t layer_end) = 0;
|
||||
@@ -138,6 +154,7 @@ class Learner : public Model, public Configurable, public dmlc::Serializable {
|
||||
* \brief Get number of boosted rounds from gradient booster.
|
||||
*/
|
||||
virtual int32_t BoostedRounds() const = 0;
|
||||
virtual uint32_t Groups() const = 0;
|
||||
|
||||
void LoadModel(Json const& in) override = 0;
|
||||
void SaveModel(Json* out) const override = 0;
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2017-2020 by Contributors
|
||||
* Copyright 2017-2021 by Contributors
|
||||
* \file predictor.h
|
||||
* \brief Interface of predictor,
|
||||
* performs predictions for a gradient booster.
|
||||
@@ -142,10 +142,14 @@ class Predictor {
|
||||
* \param [in,out] out_preds The output preds.
|
||||
* \param tree_begin (Optional) Begining of boosted trees used for prediction.
|
||||
* \param tree_end (Optional) End of booster trees. 0 means do not limit trees.
|
||||
*
|
||||
* \return True if the data can be handled by current predictor, false otherwise.
|
||||
*/
|
||||
virtual void InplacePredict(dmlc::any const &x, const gbm::GBTreeModel &model,
|
||||
float missing, PredictionCacheEntry *out_preds,
|
||||
uint32_t tree_begin = 0, uint32_t tree_end = 0) const = 0;
|
||||
virtual bool InplacePredict(dmlc::any const &x, std::shared_ptr<DMatrix> p_m,
|
||||
const gbm::GBTreeModel &model, float missing,
|
||||
PredictionCacheEntry *out_preds,
|
||||
uint32_t tree_begin = 0,
|
||||
uint32_t tree_end = 0) const = 0;
|
||||
/**
|
||||
* \brief online prediction function, predict score for one instance at a time
|
||||
* NOTE: use the batch prediction interface if possible, batch prediction is
|
||||
|
||||
Reference in New Issue
Block a user