remake the wrapper
This commit is contained in:
@@ -6,6 +6,7 @@
|
||||
* used for regression/classification/ranking
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#include <vector>
|
||||
#include "../data.h"
|
||||
|
||||
namespace xgboost {
|
||||
@@ -43,7 +44,7 @@ struct MetaInfo {
|
||||
}
|
||||
/*! \brief get weight of each instances */
|
||||
inline float GetWeight(size_t i) const {
|
||||
if(weights.size() != 0) {
|
||||
if (weights.size() != 0) {
|
||||
return weights[i];
|
||||
} else {
|
||||
return 1.0f;
|
||||
@@ -51,7 +52,7 @@ struct MetaInfo {
|
||||
}
|
||||
/*! \brief get root index of i-th instance */
|
||||
inline float GetRoot(size_t i) const {
|
||||
if(root_index.size() != 0) {
|
||||
if (root_index.size() != 0) {
|
||||
return static_cast<float>(root_index[i]);
|
||||
} else {
|
||||
return 0;
|
||||
@@ -76,7 +77,7 @@ struct MetaInfo {
|
||||
// try to load group information from file, if exists
|
||||
inline bool TryLoadGroup(const char* fname, bool silent = false) {
|
||||
FILE *fi = fopen64(fname, "r");
|
||||
if (fi == NULL) return false;
|
||||
if (fi == NULL) return false;
|
||||
group_ptr.push_back(0);
|
||||
unsigned nline;
|
||||
while (fscanf(fi, "%u", &nline) == 1) {
|
||||
@@ -110,6 +111,11 @@ struct MetaInfo {
|
||||
*/
|
||||
template<typename FMatrix>
|
||||
struct DMatrix {
|
||||
/*!
|
||||
* \brief magic number associated with this object
|
||||
* used to check if it is specific instance
|
||||
*/
|
||||
const int magic;
|
||||
/*! \brief meta information about the dataset */
|
||||
MetaInfo info;
|
||||
/*! \brief feature matrix about data content */
|
||||
@@ -120,7 +126,7 @@ struct DMatrix {
|
||||
*/
|
||||
void *cache_learner_ptr_;
|
||||
/*! \brief default constructor */
|
||||
DMatrix(void) : cache_learner_ptr_(NULL) {}
|
||||
explicit DMatrix(int magic) : magic(magic), cache_learner_ptr_(NULL) {}
|
||||
// virtual destructor
|
||||
virtual ~DMatrix(void){}
|
||||
};
|
||||
|
||||
@@ -39,7 +39,7 @@ inline IEvaluator* CreateEvaluator(const char *name) {
|
||||
if (!strcmp(name, "merror")) return new EvalMatchError();
|
||||
if (!strcmp(name, "logloss")) return new EvalLogLoss();
|
||||
if (!strcmp(name, "auc")) return new EvalAuc();
|
||||
if (!strncmp(name, "ams@",4)) return new EvalAMS(name);
|
||||
if (!strncmp(name, "ams@", 4)) return new EvalAMS(name);
|
||||
if (!strncmp(name, "pre@", 4)) return new EvalPrecision(name);
|
||||
if (!strncmp(name, "map", 3)) return new EvalMAP(name);
|
||||
if (!strncmp(name, "ndcg", 3)) return new EvalNDCG(name);
|
||||
|
||||
@@ -78,6 +78,7 @@ class BoostLearner {
|
||||
inline void SetParam(const char *name, const char *val) {
|
||||
if (!strcmp(name, "silent")) silent = atoi(val);
|
||||
if (!strcmp(name, "eval_metric")) evaluator_.AddEval(val);
|
||||
if (!strcmp("seed", name)) random::Seed(atoi(val));
|
||||
if (gbm_ == NULL) {
|
||||
if (!strcmp(name, "objective")) name_obj_ = val;
|
||||
if (!strcmp(name, "booster")) name_gbm_ = val;
|
||||
@@ -132,16 +133,24 @@ class BoostLearner {
|
||||
utils::FileStream fo(utils::FopenCheck(fname, "wb"));
|
||||
this->SaveModel(fo);
|
||||
fo.Close();
|
||||
}
|
||||
}
|
||||
/*!
|
||||
* \brief check if data matrix is ready to be used by training,
|
||||
* if not intialize it
|
||||
* \param p_train pointer to the matrix used by training
|
||||
*/
|
||||
inline void CheckInit(DMatrix<FMatrix> *p_train) const {
|
||||
p_train->fmat.InitColAccess();
|
||||
}
|
||||
/*!
|
||||
* \brief update the model for one iteration
|
||||
* \param iter current iteration number
|
||||
* \param p_train pointer to the data matrix
|
||||
*/
|
||||
inline void UpdateOneIter(int iter, DMatrix<FMatrix> *p_train) {
|
||||
this->PredictRaw(*p_train, &preds_);
|
||||
obj_->GetGradient(preds_, p_train->info, iter, &gpair_);
|
||||
gbm_->DoBoost(gpair_, p_train->fmat, p_train->info.root_index);
|
||||
inline void UpdateOneIter(int iter, const DMatrix<FMatrix> &train) {
|
||||
this->PredictRaw(train, &preds_);
|
||||
obj_->GetGradient(preds_, train.info, iter, &gpair_);
|
||||
gbm_->DoBoost(gpair_, train.fmat, train.info.root_index);
|
||||
}
|
||||
/*!
|
||||
* \brief evaluate the model for specific iteration
|
||||
|
||||
Reference in New Issue
Block a user