[LEARNER] refactor learner
This commit is contained in:
@@ -14,7 +14,7 @@
|
||||
#include <vector>
|
||||
#include "./base.h"
|
||||
#include "./gbm.h"
|
||||
#include "./meric.h"
|
||||
#include "./metric.h"
|
||||
#include "./objective.h"
|
||||
|
||||
namespace xgboost {
|
||||
@@ -36,6 +36,14 @@ namespace xgboost {
|
||||
*/
|
||||
class Learner : public rabit::Serializable {
|
||||
public:
|
||||
/*!
|
||||
* \brief set configuration from pair iterators.
|
||||
* \param begin The beginning iterator.
|
||||
* \param end The end iterator.
|
||||
* \tparam PairIter iterator<std::pair<std::string, std::string> >
|
||||
*/
|
||||
template<typename PairIter>
|
||||
inline void Configure(PairIter begin, PairIter end);
|
||||
/*!
|
||||
* \brief Set the configuration of gradient boosting.
|
||||
* User must call configure once before InitModel and Training.
|
||||
@@ -59,7 +67,7 @@ class Learner : public rabit::Serializable {
|
||||
* \param iter current iteration number
|
||||
* \param train reference to the data matrix.
|
||||
*/
|
||||
void UpdateOneIter(int iter, DMatrix* train);
|
||||
virtual void UpdateOneIter(int iter, DMatrix* train) = 0;
|
||||
/*!
|
||||
* \brief Do customized gradient boosting with in_gpair.
|
||||
* in_gair can be mutated after this call.
|
||||
@@ -67,9 +75,9 @@ class Learner : public rabit::Serializable {
|
||||
* \param train reference to the data matrix.
|
||||
* \param in_gpair The input gradient statistics.
|
||||
*/
|
||||
void BoostOneIter(int iter,
|
||||
DMatrix* train,
|
||||
std::vector<bst_gpair>* in_gpair);
|
||||
virtual void BoostOneIter(int iter,
|
||||
DMatrix* train,
|
||||
std::vector<bst_gpair>* in_gpair) = 0;
|
||||
/*!
|
||||
* \brief evaluate the model for specific iteration using the configured metrics.
|
||||
* \param iter iteration number
|
||||
@@ -77,9 +85,9 @@ class Learner : public rabit::Serializable {
|
||||
* \param data_names name of each dataset
|
||||
* \return a string corresponding to the evaluation result
|
||||
*/
|
||||
std::string EvalOneIter(int iter,
|
||||
const std::vector<DMatrix*>& data_sets,
|
||||
const std::vector<std::string>& data_names);
|
||||
virtual std::string EvalOneIter(int iter,
|
||||
const std::vector<DMatrix*>& data_sets,
|
||||
const std::vector<std::string>& data_names) = 0;
|
||||
/*!
|
||||
* \brief get prediction given the model.
|
||||
* \param data input data
|
||||
@@ -89,11 +97,11 @@ class Learner : public rabit::Serializable {
|
||||
* predictor, when it equals 0, this means we are using all the trees
|
||||
* \param pred_leaf whether to only predict the leaf index of each tree in a boosted tree predictor
|
||||
*/
|
||||
void Predict(DMatrix* data,
|
||||
bool output_margin,
|
||||
std::vector<float> *out_preds,
|
||||
unsigned ntree_limit = 0,
|
||||
bool pred_leaf = false) const;
|
||||
virtual void Predict(DMatrix* data,
|
||||
bool output_margin,
|
||||
std::vector<float> *out_preds,
|
||||
unsigned ntree_limit = 0,
|
||||
bool pred_leaf = false) const = 0;
|
||||
/*!
|
||||
* \return whether the model allow lazy checkpoint in rabit.
|
||||
*/
|
||||
@@ -151,5 +159,13 @@ inline void Learner::Predict(const SparseBatch::Inst& inst,
|
||||
obj_->PredTransform(out_preds);
|
||||
}
|
||||
}
|
||||
|
||||
// implementing configure.
|
||||
template<typename PairIter>
|
||||
inline void Learner::Configure(PairIter begin, PairIter end) {
|
||||
std::vector<std::pair<std::string, std::string> > vec(begin, end);
|
||||
this->Configure(vec);
|
||||
}
|
||||
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_LEARNER_H_
|
||||
|
||||
Reference in New Issue
Block a user