Enhance inplace prediction. (#6653)

* Accept array interface for csr and array.
* Accept an optional proxy dmatrix for metainfo.

This constructs an explicit `_ProxyDMatrix` type in Python.

* Remove unused doc.
* Add strict output.
This commit is contained in:
Jiaming Yuan
2021-02-02 11:41:46 +08:00
committed by GitHub
parent 87ab1ad607
commit 411592a347
22 changed files with 955 additions and 530 deletions

View File

@@ -1,5 +1,5 @@
/*!
* Copyright 2014-2020 by Contributors
* Copyright 2014-2021 by Contributors
* \file gbm.h
* \brief Interface of gradient booster,
* that learns through gradient statistics.
@@ -118,7 +118,7 @@ class GradientBooster : public Model, public Configurable {
* \param layer_begin (Optional) Begining of boosted tree layer used for prediction.
* \param layer_end (Optional) End of booster layer. 0 means do not limit trees.
*/
virtual void InplacePredict(dmlc::any const &, float,
virtual void InplacePredict(dmlc::any const &, std::shared_ptr<DMatrix>, float,
PredictionCacheEntry*,
uint32_t,
uint32_t) const {

View File

@@ -308,6 +308,7 @@ struct StringView {
public:
StringView() = default;
StringView(CharT const* str, size_t size) : str_{str}, size_{size} {}
explicit StringView(std::string const& str): str_{str.c_str()}, size_{str.size()} {}
explicit StringView(CharT const* str) : str_{str}, size_{Traits::length(str)} {}
CharT const& operator[](size_t p) const { return str_[p]; }

View File

@@ -1,5 +1,5 @@
/*!
* Copyright 2015-2020 by Contributors
* Copyright 2015-2021 by Contributors
* \file learner.h
* \brief Learner interface that integrates objective, gbm and evaluation together.
* This is the user facing XGBoost training module.
@@ -30,6 +30,15 @@ class ObjFunction;
class DMatrix;
class Json;
enum class PredictionType : std::uint8_t { // NOLINT
kValue = 0,
kMargin = 1,
kContribution = 2,
kApproxContribution = 3,
kInteraction = 4,
kLeaf = 5
};
/*! \brief entry to to easily hold returning information */
struct XGBAPIThreadLocalEntry {
/*! \brief result holder for returning string */
@@ -42,7 +51,10 @@ struct XGBAPIThreadLocalEntry {
std::vector<bst_float> ret_vec_float;
/*! \brief temp variable of gradient pairs. */
std::vector<GradientPair> tmp_gpair;
/*! \brief Temp variable for returing prediction result. */
PredictionCacheEntry prediction_entry;
/*! \brief Temp variable for returing prediction shape. */
std::vector<bst_ulong> prediction_shape;
};
/*!
@@ -123,13 +135,17 @@ class Learner : public Model, public Configurable, public dmlc::Serializable {
* \brief Inplace prediction.
*
* \param x A type erased data adapter.
* \param p_m An optional Proxy DMatrix object storing meta info like
* base margin. Can be nullptr.
* \param type Prediction type.
* \param missing Missing value in the data.
* \param [in,out] out_preds Pointer to output prediction vector.
* \param layer_begin (Optional) Begining of boosted tree layer used for prediction.
* \param layer_end (Optional) End of booster layer. 0 means do not limit trees.
* \param layer_begin Begining of boosted tree layer used for prediction.
* \param layer_end End of booster layer. 0 means do not limit trees.
*/
virtual void InplacePredict(dmlc::any const& x, std::string const& type,
virtual void InplacePredict(dmlc::any const &x,
std::shared_ptr<DMatrix> p_m,
PredictionType type,
float missing,
HostDeviceVector<bst_float> **out_preds,
uint32_t layer_begin, uint32_t layer_end) = 0;
@@ -138,6 +154,7 @@ class Learner : public Model, public Configurable, public dmlc::Serializable {
* \brief Get number of boosted rounds from gradient booster.
*/
virtual int32_t BoostedRounds() const = 0;
virtual uint32_t Groups() const = 0;
void LoadModel(Json const& in) override = 0;
void SaveModel(Json* out) const override = 0;

View File

@@ -1,5 +1,5 @@
/*!
* Copyright 2017-2020 by Contributors
* Copyright 2017-2021 by Contributors
* \file predictor.h
* \brief Interface of predictor,
* performs predictions for a gradient booster.
@@ -142,10 +142,14 @@ class Predictor {
* \param [in,out] out_preds The output preds.
* \param tree_begin (Optional) Begining of boosted trees used for prediction.
* \param tree_end (Optional) End of booster trees. 0 means do not limit trees.
*
* \return True if the data can be handled by current predictor, false otherwise.
*/
virtual void InplacePredict(dmlc::any const &x, const gbm::GBTreeModel &model,
float missing, PredictionCacheEntry *out_preds,
uint32_t tree_begin = 0, uint32_t tree_end = 0) const = 0;
virtual bool InplacePredict(dmlc::any const &x, std::shared_ptr<DMatrix> p_m,
const gbm::GBTreeModel &model, float missing,
PredictionCacheEntry *out_preds,
uint32_t tree_begin = 0,
uint32_t tree_end = 0) const = 0;
/**
* \brief online prediction function, predict score for one instance at a time
* NOTE: use the batch prediction interface if possible, batch prediction is