[LEARNER] refactor learner

This commit is contained in:
tqchen
2016-01-04 01:31:44 -08:00
parent 4b4b36d047
commit 0d95e863c9
14 changed files with 470 additions and 517 deletions

View File

@@ -14,7 +14,7 @@
#include <vector>
#include "./base.h"
#include "./gbm.h"
#include "./meric.h"
#include "./metric.h"
#include "./objective.h"
namespace xgboost {
@@ -36,6 +36,14 @@ namespace xgboost {
*/
class Learner : public rabit::Serializable {
public:
/*!
* \brief set configuration from pair iterators.
* \param begin The beginning iterator.
* \param end The end iterator.
* \tparam PairIter iterator<std::pair<std::string, std::string> >
*/
template<typename PairIter>
inline void Configure(PairIter begin, PairIter end);
/*!
* \brief Set the configuration of gradient boosting.
* User must call configure once before InitModel and Training.
@@ -59,7 +67,7 @@ class Learner : public rabit::Serializable {
* \param iter current iteration number
* \param train reference to the data matrix.
*/
void UpdateOneIter(int iter, DMatrix* train);
virtual void UpdateOneIter(int iter, DMatrix* train) = 0;
/*!
* \brief Do customized gradient boosting with in_gpair.
* in_gpair can be mutated after this call.
@@ -67,9 +75,9 @@ class Learner : public rabit::Serializable {
* \param train reference to the data matrix.
* \param in_gpair The input gradient statistics.
*/
void BoostOneIter(int iter,
DMatrix* train,
std::vector<bst_gpair>* in_gpair);
virtual void BoostOneIter(int iter,
DMatrix* train,
std::vector<bst_gpair>* in_gpair) = 0;
/*!
* \brief evaluate the model for specific iteration using the configured metrics.
* \param iter iteration number
@@ -77,9 +85,9 @@ class Learner : public rabit::Serializable {
* \param data_names name of each dataset
* \return a string corresponding to the evaluation result
*/
std::string EvalOneIter(int iter,
const std::vector<DMatrix*>& data_sets,
const std::vector<std::string>& data_names);
virtual std::string EvalOneIter(int iter,
const std::vector<DMatrix*>& data_sets,
const std::vector<std::string>& data_names) = 0;
/*!
* \brief get prediction given the model.
* \param data input data
@@ -89,11 +97,11 @@ class Learner : public rabit::Serializable {
* predictor, when it equals 0, this means we are using all the trees
* \param pred_leaf whether to only predict the leaf index of each tree in a boosted tree predictor
*/
void Predict(DMatrix* data,
bool output_margin,
std::vector<float> *out_preds,
unsigned ntree_limit = 0,
bool pred_leaf = false) const;
virtual void Predict(DMatrix* data,
bool output_margin,
std::vector<float> *out_preds,
unsigned ntree_limit = 0,
bool pred_leaf = false) const = 0;
/*!
* \return whether the model allows lazy checkpoint in rabit.
*/
@@ -151,5 +159,13 @@ inline void Learner::Predict(const SparseBatch::Inst& inst,
obj_->PredTransform(out_preds);
}
}
/*!
 * \brief Iterator-range form of Configure.
 *  Copies the elements of [begin, end) into a vector of
 *  (string, string) pairs and passes that vector to the
 *  Configure overload that accepts such a vector.
 * \param begin The beginning iterator.
 * \param end The end iterator.
 * \tparam PairIter iterator<std::pair<std::string, std::string> >
 */
template<typename PairIter>
inline void Learner::Configure(PairIter begin, PairIter end) {
  typedef std::pair<std::string, std::string> ConfigEntry;
  // materialize the range so the vector-based overload can consume it
  std::vector<ConfigEntry> entries(begin, end);
  this->Configure(entries);
}
} // namespace xgboost
#endif // XGBOOST_LEARNER_H_