[CORE] Refactor cache mechanism (#1540)
This commit is contained in:
@@ -13,8 +13,10 @@
|
||||
#include <utility>
|
||||
#include <string>
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include "./base.h"
|
||||
#include "./data.h"
|
||||
#include "./objective.h"
|
||||
#include "./feature_map.h"
|
||||
|
||||
namespace xgboost {
|
||||
@@ -50,13 +52,6 @@ class GradientBooster {
|
||||
* \param fo output stream
|
||||
*/
|
||||
virtual void Save(dmlc::Stream* fo) const = 0;
|
||||
/*!
|
||||
* \brief reset the predict buffer size.
|
||||
* This will invalidate all the previous cached results
|
||||
* and recalculate from scratch
|
||||
* \param num_pbuffer The size of predict buffer.
|
||||
*/
|
||||
virtual void ResetPredBuffer(size_t num_pbuffer) {}
|
||||
/*!
|
||||
* \brief whether the model allow lazy checkpoint
|
||||
* return true if model is only updated in DoBoost
|
||||
@@ -68,27 +63,21 @@ class GradientBooster {
|
||||
/*!
|
||||
* \brief perform update to the model(boosting)
|
||||
* \param p_fmat feature matrix that provide access to features
|
||||
* \param buffer_offset buffer index offset of these instances, if equals -1
|
||||
* this means we do not have buffer index allocated to the gbm
|
||||
* \param in_gpair address of the gradient pair statistics of the data
|
||||
* \param obj The objective function, optional, can be nullptr when use customized version
|
||||
* the booster may change content of gpair
|
||||
*/
|
||||
virtual void DoBoost(DMatrix* p_fmat,
|
||||
int64_t buffer_offset,
|
||||
std::vector<bst_gpair>* in_gpair) = 0;
|
||||
std::vector<bst_gpair>* in_gpair,
|
||||
ObjFunction* obj = nullptr) = 0;
|
||||
/*!
|
||||
* \brief generate predictions for given feature matrix
|
||||
* \param dmat feature matrix
|
||||
* \param buffer_offset buffer index offset of these instances, if equals -1
|
||||
* this means we do not have buffer index allocated to the gbm
|
||||
* a buffer index is assigned to each instance that requires repeative prediction
|
||||
* the size of buffer is set by convention using GradientBooster.ResetPredBuffer(size);
|
||||
* \param out_preds output vector to hold the predictions
|
||||
* \param ntree_limit limit the number of trees used in prediction, when it equals 0, this means
|
||||
* we do not limit number of trees, this parameter is only valid for gbtree, but not for gblinear
|
||||
*/
|
||||
virtual void Predict(DMatrix* dmat,
|
||||
int64_t buffer_offset,
|
||||
std::vector<float>* out_preds,
|
||||
unsigned ntree_limit = 0) = 0;
|
||||
/*!
|
||||
@@ -128,9 +117,14 @@ class GradientBooster {
|
||||
/*!
|
||||
* \brief create a gradient booster from given name
|
||||
* \param name name of gradient booster
|
||||
* \param cache_mats The cache data matrix of the Booster.
|
||||
* \param base_margin The base margin of prediction.
|
||||
* \return The created booster.
|
||||
*/
|
||||
static GradientBooster* Create(const std::string& name);
|
||||
static GradientBooster* Create(
|
||||
const std::string& name,
|
||||
const std::vector<std::shared_ptr<DMatrix> >& cache_mats,
|
||||
float base_margin);
|
||||
};
|
||||
|
||||
// implementing configure.
|
||||
@@ -144,8 +138,10 @@ inline void GradientBooster::Configure(PairIter begin, PairIter end) {
|
||||
* \brief Registry entry for tree updater.
|
||||
*/
|
||||
struct GradientBoosterReg
|
||||
: public dmlc::FunctionRegEntryBase<GradientBoosterReg,
|
||||
std::function<GradientBooster* ()> > {
|
||||
: public dmlc::FunctionRegEntryBase<
|
||||
GradientBoosterReg,
|
||||
std::function<GradientBooster* (const std::vector<std::shared_ptr<DMatrix> > &cached_mats,
|
||||
float base_margin)> > {
|
||||
};
|
||||
|
||||
/*!
|
||||
|
||||
@@ -166,7 +166,7 @@ class Learner : public rabit::Serializable {
|
||||
* \param cache_data The matrix to cache the prediction.
|
||||
* \return Created learner.
|
||||
*/
|
||||
static Learner* Create(const std::vector<DMatrix*>& cache_data);
|
||||
static Learner* Create(const std::vector<std::shared_ptr<DMatrix> >& cache_data);
|
||||
|
||||
protected:
|
||||
/*! \brief internal base score of the model */
|
||||
|
||||
Reference in New Issue
Block a user