[LEARNER] refactor learner

This commit is contained in:
tqchen
2016-01-04 01:31:44 -08:00
parent 4b4b36d047
commit 0d95e863c9
14 changed files with 470 additions and 517 deletions

View File

@@ -14,6 +14,9 @@
#include "./base.h"
namespace xgboost {
// forward declare learner.
class LearnerImpl;
/*! \brief data type accepted by xgboost interface */
enum DataType {
kFloat32 = 1,
@@ -199,6 +202,8 @@ class DataSource : public dmlc::DataIter<RowBatch> {
*/
class DMatrix {
public:
/*! \brief default constructor */
DMatrix() : cache_learner_ptr_(nullptr) {}
/*! \brief meta information of the dataset */
virtual MetaInfo& info() = 0;
/*! \brief meta information of the dataset */
@@ -222,6 +227,7 @@ class DMatrix {
* \param subsample subsample ratio when generating column access.
* \param max_row_perbatch auxilary information, maximum row used in each column batch.
* this is a hint information that can be ignored by the implementation.
* \return Number of column blocks in the column access.
*/
virtual void InitColAccess(const std::vector<bool>& enabled,
float subsample,
@@ -229,6 +235,8 @@ class DMatrix {
// the following are column meta data, should be able to answer them fast.
/*! \return whether column access is enabled */
virtual bool HaveColAccess() const = 0;
/*! \return Whether the data columns single column block. */
virtual bool SingleColBlock() const = 0;
/*! \brief get number of non-missing entries in column */
virtual size_t GetColSize(size_t cidx) const = 0;
/*! \brief get column density */
@@ -279,6 +287,12 @@ class DMatrix {
*/
static DMatrix* Create(dmlc::Parser<uint32_t>* parser,
const char* cache_prefix = nullptr);
private:
// allow learner class to access this field.
friend class LearnerImpl;
/*! \brief public field to back ref cached matrix. */
LearnerImpl* cache_learner_ptr_;
};
} // namespace xgboost

View File

@@ -25,6 +25,14 @@ class GradientBooster {
public:
/*! \brief virtual destructor */
virtual ~GradientBooster() {}
/*!
* \brief set configuration from pair iterators.
* \param begin The beginning iterator.
* \param end The end iterator.
* \tparam PairIter iterator<std::pair<std::string, std::string> >
*/
template<typename PairIter>
inline void Configure(PairIter begin, PairIter end);
/*!
* \brief Set the configuration of gradient boosting.
* User must call configure once before InitModel and Training.
@@ -123,9 +131,16 @@ class GradientBooster {
* \breif create a gradient booster from given name
* \param name name of gradient booster
*/
static GradientBooster* Create(const char *name);
static GradientBooster* Create(const std::string& name);
};
// implementing configure.
template<typename PairIter>
inline void GradientBooster::Configure(PairIter begin, PairIter end) {
std::vector<std::pair<std::string, std::string> > vec(begin, end);
this->Configure(vec);
}
/*!
* \brief Registry entry for tree updater.
*/

View File

@@ -14,7 +14,7 @@
#include <vector>
#include "./base.h"
#include "./gbm.h"
#include "./meric.h"
#include "./metric.h"
#include "./objective.h"
namespace xgboost {
@@ -36,6 +36,14 @@ namespace xgboost {
*/
class Learner : public rabit::Serializable {
public:
/*!
* \brief set configuration from pair iterators.
* \param begin The beginning iterator.
* \param end The end iterator.
* \tparam PairIter iterator<std::pair<std::string, std::string> >
*/
template<typename PairIter>
inline void Configure(PairIter begin, PairIter end);
/*!
* \brief Set the configuration of gradient boosting.
* User must call configure once before InitModel and Training.
@@ -59,7 +67,7 @@ class Learner : public rabit::Serializable {
* \param iter current iteration number
* \param train reference to the data matrix.
*/
void UpdateOneIter(int iter, DMatrix* train);
virtual void UpdateOneIter(int iter, DMatrix* train) = 0;
/*!
* \brief Do customized gradient boosting with in_gpair.
* in_gair can be mutated after this call.
@@ -67,9 +75,9 @@ class Learner : public rabit::Serializable {
* \param train reference to the data matrix.
* \param in_gpair The input gradient statistics.
*/
void BoostOneIter(int iter,
DMatrix* train,
std::vector<bst_gpair>* in_gpair);
virtual void BoostOneIter(int iter,
DMatrix* train,
std::vector<bst_gpair>* in_gpair) = 0;
/*!
* \brief evaluate the model for specific iteration using the configured metrics.
* \param iter iteration number
@@ -77,9 +85,9 @@ class Learner : public rabit::Serializable {
* \param data_names name of each dataset
* \return a string corresponding to the evaluation result
*/
std::string EvalOneIter(int iter,
const std::vector<DMatrix*>& data_sets,
const std::vector<std::string>& data_names);
virtual std::string EvalOneIter(int iter,
const std::vector<DMatrix*>& data_sets,
const std::vector<std::string>& data_names) = 0;
/*!
* \brief get prediction given the model.
* \param data input data
@@ -89,11 +97,11 @@ class Learner : public rabit::Serializable {
* predictor, when it equals 0, this means we are using all the trees
* \param pred_leaf whether to only predict the leaf index of each tree in a boosted tree predictor
*/
void Predict(DMatrix* data,
bool output_margin,
std::vector<float> *out_preds,
unsigned ntree_limit = 0,
bool pred_leaf = false) const;
virtual void Predict(DMatrix* data,
bool output_margin,
std::vector<float> *out_preds,
unsigned ntree_limit = 0,
bool pred_leaf = false) const = 0;
/*!
* \return whether the model allow lazy checkpoint in rabit.
*/
@@ -151,5 +159,13 @@ inline void Learner::Predict(const SparseBatch::Inst& inst,
obj_->PredTransform(out_preds);
}
}
// implementing configure.
template<typename PairIter>
inline void Learner::Configure(PairIter begin, PairIter end) {
std::vector<std::pair<std::string, std::string> > vec(begin, end);
this->Configure(vec);
}
} // namespace xgboost
#endif // XGBOOST_LEARNER_H_

View File

@@ -9,6 +9,7 @@
#include <dmlc/registry.h>
#include <vector>
#include <string>
#include <functional>
#include "./data.h"
#include "./base.h"
@@ -42,7 +43,7 @@ class Metric {
* and the name will be matched in the registry.
* \return the created metric.
*/
static Metric* Create(const char *name);
static Metric* Create(const std::string& name);
};
/*!

View File

@@ -22,10 +22,18 @@ class ObjFunction {
/*! \brief virtual destructor */
virtual ~ObjFunction() {}
/*!
* \brief Initialize the objective with the specified parameters.
* \brief set configuration from pair iterators.
* \param begin The beginning iterator.
* \param end The end iterator.
* \tparam PairIter iterator<std::pair<std::string, std::string> >
*/
template<typename PairIter>
inline void Configure(PairIter begin, PairIter end);
/*!
* \brief Configure the objective with the specified parameters.
* \param args arguments to the objective function.
*/
virtual void Init(const std::vector<std::pair<std::string, std::string> >& args) = 0;
virtual void Configure(const std::vector<std::pair<std::string, std::string> >& args) = 0;
/*!
* \brief Get gradient over each of predictions, given existing information.
* \param preds prediction of current round
@@ -66,9 +74,16 @@ class ObjFunction {
* \brief Create an objective function according to name.
* \param name Name of the objective.
*/
static ObjFunction* Create(const char* name);
static ObjFunction* Create(const std::string& name);
};
// implementing configure.
template<typename PairIter>
inline void ObjFunction::Configure(PairIter begin, PairIter end) {
std::vector<std::pair<std::string, std::string> > vec(begin, end);
this->Configure(vec);
}
/*!
* \brief Registry entry for objective factory functions.
*/

View File

@@ -54,7 +54,7 @@ class TreeUpdater {
* \brief Create a tree updater given name
* \param name Name of the tree updater.
*/
static TreeUpdater* Create(const char* name);
static TreeUpdater* Create(const std::string& name);
};
/*!