[LEARNER] refactor learner
This commit is contained in:
@@ -14,6 +14,9 @@
|
||||
#include "./base.h"
|
||||
|
||||
namespace xgboost {
|
||||
// forward declare learner.
|
||||
class LearnerImpl;
|
||||
|
||||
/*! \brief data type accepted by xgboost interface */
|
||||
enum DataType {
|
||||
kFloat32 = 1,
|
||||
@@ -199,6 +202,8 @@ class DataSource : public dmlc::DataIter<RowBatch> {
|
||||
*/
|
||||
class DMatrix {
|
||||
public:
|
||||
/*! \brief default constructor */
|
||||
DMatrix() : cache_learner_ptr_(nullptr) {}
|
||||
/*! \brief meta information of the dataset */
|
||||
virtual MetaInfo& info() = 0;
|
||||
/*! \brief meta information of the dataset */
|
||||
@@ -222,6 +227,7 @@ class DMatrix {
|
||||
* \param subsample subsample ratio when generating column access.
|
||||
* \param max_row_perbatch auxilary information, maximum row used in each column batch.
|
||||
* this is a hint information that can be ignored by the implementation.
|
||||
* \return Number of column blocks in the column access.
|
||||
*/
|
||||
virtual void InitColAccess(const std::vector<bool>& enabled,
|
||||
float subsample,
|
||||
@@ -229,6 +235,8 @@ class DMatrix {
|
||||
// the following are column meta data, should be able to answer them fast.
|
||||
/*! \return whether column access is enabled */
|
||||
virtual bool HaveColAccess() const = 0;
|
||||
/*! \return Whether the data columns single column block. */
|
||||
virtual bool SingleColBlock() const = 0;
|
||||
/*! \brief get number of non-missing entries in column */
|
||||
virtual size_t GetColSize(size_t cidx) const = 0;
|
||||
/*! \brief get column density */
|
||||
@@ -279,6 +287,12 @@ class DMatrix {
|
||||
*/
|
||||
static DMatrix* Create(dmlc::Parser<uint32_t>* parser,
|
||||
const char* cache_prefix = nullptr);
|
||||
|
||||
private:
|
||||
// allow learner class to access this field.
|
||||
friend class LearnerImpl;
|
||||
/*! \brief public field to back ref cached matrix. */
|
||||
LearnerImpl* cache_learner_ptr_;
|
||||
};
|
||||
|
||||
} // namespace xgboost
|
||||
|
||||
@@ -25,6 +25,14 @@ class GradientBooster {
|
||||
public:
|
||||
/*! \brief virtual destructor */
|
||||
virtual ~GradientBooster() {}
|
||||
/*!
|
||||
* \brief set configuration from pair iterators.
|
||||
* \param begin The beginning iterator.
|
||||
* \param end The end iterator.
|
||||
* \tparam PairIter iterator<std::pair<std::string, std::string> >
|
||||
*/
|
||||
template<typename PairIter>
|
||||
inline void Configure(PairIter begin, PairIter end);
|
||||
/*!
|
||||
* \brief Set the configuration of gradient boosting.
|
||||
* User must call configure once before InitModel and Training.
|
||||
@@ -123,9 +131,16 @@ class GradientBooster {
|
||||
* \breif create a gradient booster from given name
|
||||
* \param name name of gradient booster
|
||||
*/
|
||||
static GradientBooster* Create(const char *name);
|
||||
static GradientBooster* Create(const std::string& name);
|
||||
};
|
||||
|
||||
// implementing configure.
|
||||
template<typename PairIter>
|
||||
inline void GradientBooster::Configure(PairIter begin, PairIter end) {
|
||||
std::vector<std::pair<std::string, std::string> > vec(begin, end);
|
||||
this->Configure(vec);
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief Registry entry for tree updater.
|
||||
*/
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
#include <vector>
|
||||
#include "./base.h"
|
||||
#include "./gbm.h"
|
||||
#include "./meric.h"
|
||||
#include "./metric.h"
|
||||
#include "./objective.h"
|
||||
|
||||
namespace xgboost {
|
||||
@@ -36,6 +36,14 @@ namespace xgboost {
|
||||
*/
|
||||
class Learner : public rabit::Serializable {
|
||||
public:
|
||||
/*!
|
||||
* \brief set configuration from pair iterators.
|
||||
* \param begin The beginning iterator.
|
||||
* \param end The end iterator.
|
||||
* \tparam PairIter iterator<std::pair<std::string, std::string> >
|
||||
*/
|
||||
template<typename PairIter>
|
||||
inline void Configure(PairIter begin, PairIter end);
|
||||
/*!
|
||||
* \brief Set the configuration of gradient boosting.
|
||||
* User must call configure once before InitModel and Training.
|
||||
@@ -59,7 +67,7 @@ class Learner : public rabit::Serializable {
|
||||
* \param iter current iteration number
|
||||
* \param train reference to the data matrix.
|
||||
*/
|
||||
void UpdateOneIter(int iter, DMatrix* train);
|
||||
virtual void UpdateOneIter(int iter, DMatrix* train) = 0;
|
||||
/*!
|
||||
* \brief Do customized gradient boosting with in_gpair.
|
||||
* in_gair can be mutated after this call.
|
||||
@@ -67,9 +75,9 @@ class Learner : public rabit::Serializable {
|
||||
* \param train reference to the data matrix.
|
||||
* \param in_gpair The input gradient statistics.
|
||||
*/
|
||||
void BoostOneIter(int iter,
|
||||
DMatrix* train,
|
||||
std::vector<bst_gpair>* in_gpair);
|
||||
virtual void BoostOneIter(int iter,
|
||||
DMatrix* train,
|
||||
std::vector<bst_gpair>* in_gpair) = 0;
|
||||
/*!
|
||||
* \brief evaluate the model for specific iteration using the configured metrics.
|
||||
* \param iter iteration number
|
||||
@@ -77,9 +85,9 @@ class Learner : public rabit::Serializable {
|
||||
* \param data_names name of each dataset
|
||||
* \return a string corresponding to the evaluation result
|
||||
*/
|
||||
std::string EvalOneIter(int iter,
|
||||
const std::vector<DMatrix*>& data_sets,
|
||||
const std::vector<std::string>& data_names);
|
||||
virtual std::string EvalOneIter(int iter,
|
||||
const std::vector<DMatrix*>& data_sets,
|
||||
const std::vector<std::string>& data_names) = 0;
|
||||
/*!
|
||||
* \brief get prediction given the model.
|
||||
* \param data input data
|
||||
@@ -89,11 +97,11 @@ class Learner : public rabit::Serializable {
|
||||
* predictor, when it equals 0, this means we are using all the trees
|
||||
* \param pred_leaf whether to only predict the leaf index of each tree in a boosted tree predictor
|
||||
*/
|
||||
void Predict(DMatrix* data,
|
||||
bool output_margin,
|
||||
std::vector<float> *out_preds,
|
||||
unsigned ntree_limit = 0,
|
||||
bool pred_leaf = false) const;
|
||||
virtual void Predict(DMatrix* data,
|
||||
bool output_margin,
|
||||
std::vector<float> *out_preds,
|
||||
unsigned ntree_limit = 0,
|
||||
bool pred_leaf = false) const = 0;
|
||||
/*!
|
||||
* \return whether the model allow lazy checkpoint in rabit.
|
||||
*/
|
||||
@@ -151,5 +159,13 @@ inline void Learner::Predict(const SparseBatch::Inst& inst,
|
||||
obj_->PredTransform(out_preds);
|
||||
}
|
||||
}
|
||||
|
||||
// implementing configure.
|
||||
template<typename PairIter>
|
||||
inline void Learner::Configure(PairIter begin, PairIter end) {
|
||||
std::vector<std::pair<std::string, std::string> > vec(begin, end);
|
||||
this->Configure(vec);
|
||||
}
|
||||
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_LEARNER_H_
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
|
||||
#include <dmlc/registry.h>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <functional>
|
||||
#include "./data.h"
|
||||
#include "./base.h"
|
||||
@@ -42,7 +43,7 @@ class Metric {
|
||||
* and the name will be matched in the registry.
|
||||
* \return the created metric.
|
||||
*/
|
||||
static Metric* Create(const char *name);
|
||||
static Metric* Create(const std::string& name);
|
||||
};
|
||||
|
||||
/*!
|
||||
|
||||
@@ -22,10 +22,18 @@ class ObjFunction {
|
||||
/*! \brief virtual destructor */
|
||||
virtual ~ObjFunction() {}
|
||||
/*!
|
||||
* \brief Initialize the objective with the specified parameters.
|
||||
* \brief set configuration from pair iterators.
|
||||
* \param begin The beginning iterator.
|
||||
* \param end The end iterator.
|
||||
* \tparam PairIter iterator<std::pair<std::string, std::string> >
|
||||
*/
|
||||
template<typename PairIter>
|
||||
inline void Configure(PairIter begin, PairIter end);
|
||||
/*!
|
||||
* \brief Configure the objective with the specified parameters.
|
||||
* \param args arguments to the objective function.
|
||||
*/
|
||||
virtual void Init(const std::vector<std::pair<std::string, std::string> >& args) = 0;
|
||||
virtual void Configure(const std::vector<std::pair<std::string, std::string> >& args) = 0;
|
||||
/*!
|
||||
* \brief Get gradient over each of predictions, given existing information.
|
||||
* \param preds prediction of current round
|
||||
@@ -66,9 +74,16 @@ class ObjFunction {
|
||||
* \brief Create an objective function according to name.
|
||||
* \param name Name of the objective.
|
||||
*/
|
||||
static ObjFunction* Create(const char* name);
|
||||
static ObjFunction* Create(const std::string& name);
|
||||
};
|
||||
|
||||
// implementing configure.
|
||||
template<typename PairIter>
|
||||
inline void ObjFunction::Configure(PairIter begin, PairIter end) {
|
||||
std::vector<std::pair<std::string, std::string> > vec(begin, end);
|
||||
this->Configure(vec);
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief Registry entry for objective factory functions.
|
||||
*/
|
||||
|
||||
@@ -54,7 +54,7 @@ class TreeUpdater {
|
||||
* \brief Create a tree updater given name
|
||||
* \param name Name of the tree updater.
|
||||
*/
|
||||
static TreeUpdater* Create(const char* name);
|
||||
static TreeUpdater* Create(const std::string& name);
|
||||
};
|
||||
|
||||
/*!
|
||||
|
||||
Reference in New Issue
Block a user