[CORE] Refactor cache mechanism (#1540)

This commit is contained in:
Tianqi Chen
2016-09-02 20:39:07 -07:00
committed by GitHub
parent 6dabdd33e3
commit ecec5f7959
9 changed files with 320 additions and 421 deletions

View File

@@ -13,8 +13,10 @@
#include <utility>
#include <string>
#include <functional>
#include <memory>
#include "./base.h"
#include "./data.h"
#include "./objective.h"
#include "./feature_map.h"
namespace xgboost {
@@ -50,13 +52,6 @@ class GradientBooster {
* \param fo output stream
*/
virtual void Save(dmlc::Stream* fo) const = 0;
/*!
* \brief reset the predict buffer size.
* This will invalidate all the previous cached results
* and recalculate from scratch
* \param num_pbuffer The size of predict buffer.
*/
virtual void ResetPredBuffer(size_t num_pbuffer) {}
/*!
* \brief whether the model allow lazy checkpoint
* return true if model is only updated in DoBoost
@@ -68,27 +63,21 @@ class GradientBooster {
/*!
* \brief perform update to the model(boosting)
* \param p_fmat feature matrix that provide access to features
* \param buffer_offset buffer index offset of these instances, if equals -1
* this means we do not have buffer index allocated to the gbm
* \param in_gpair address of the gradient pair statistics of the data
* \param obj The objective function, optional, can be nullptr when use customized version
* the booster may change content of gpair
*/
virtual void DoBoost(DMatrix* p_fmat,
int64_t buffer_offset,
std::vector<bst_gpair>* in_gpair) = 0;
std::vector<bst_gpair>* in_gpair,
ObjFunction* obj = nullptr) = 0;
/*!
* \brief generate predictions for given feature matrix
* \param dmat feature matrix
* \param buffer_offset buffer index offset of these instances, if equals -1
* this means we do not have buffer index allocated to the gbm
* a buffer index is assigned to each instance that requires repeative prediction
* the size of buffer is set by convention using GradientBooster.ResetPredBuffer(size);
* \param out_preds output vector to hold the predictions
* \param ntree_limit limit the number of trees used in prediction, when it equals 0, this means
* we do not limit number of trees, this parameter is only valid for gbtree, but not for gblinear
*/
virtual void Predict(DMatrix* dmat,
int64_t buffer_offset,
std::vector<float>* out_preds,
unsigned ntree_limit = 0) = 0;
/*!
@@ -128,9 +117,14 @@ class GradientBooster {
/*!
* \brief create a gradient booster from given name
* \param name name of gradient booster
* \param cache_mats The cache data matrix of the Booster.
* \param base_margin The base margin of prediction.
* \return The created booster.
*/
static GradientBooster* Create(const std::string& name);
static GradientBooster* Create(
const std::string& name,
const std::vector<std::shared_ptr<DMatrix> >& cache_mats,
float base_margin);
};
// implementing configure.
@@ -144,8 +138,10 @@ inline void GradientBooster::Configure(PairIter begin, PairIter end) {
* \brief Registry entry for tree updater.
*/
struct GradientBoosterReg
: public dmlc::FunctionRegEntryBase<GradientBoosterReg,
std::function<GradientBooster* ()> > {
: public dmlc::FunctionRegEntryBase<
GradientBoosterReg,
std::function<GradientBooster* (const std::vector<std::shared_ptr<DMatrix> > &cached_mats,
float base_margin)> > {
};
/*!

View File

@@ -166,7 +166,7 @@ class Learner : public rabit::Serializable {
* \param cache_data The matrix to cache the prediction.
* \return Created learner.
*/
static Learner* Create(const std::vector<DMatrix*>& cache_data);
static Learner* Create(const std::vector<std::shared_ptr<DMatrix> >& cache_data);
protected:
/*! \brief internal base score of the model */