finish refactor, need debug
This commit is contained in:
@@ -7,6 +7,7 @@
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#include <vector>
|
||||
#include <cstring>
|
||||
#include "../data.h"
|
||||
#include "../utils/io.h"
|
||||
namespace xgboost {
|
||||
@@ -150,8 +151,6 @@ struct DMatrix {
|
||||
const int magic;
|
||||
/*! \brief meta information about the dataset */
|
||||
MetaInfo info;
|
||||
/*! \brief feature matrix about data content */
|
||||
IFMatrix *fmat;
|
||||
/*!
|
||||
* \brief cache pointer to verify if the data structure is cached in some learner
|
||||
* used to verify if DMatrix is cached
|
||||
@@ -159,10 +158,10 @@ struct DMatrix {
|
||||
void *cache_learner_ptr_;
|
||||
/*! \brief default constructor */
|
||||
explicit DMatrix(int magic) : magic(magic), cache_learner_ptr_(NULL) {}
|
||||
/*! \brief get feature matrix about data content */
|
||||
virtual IFMatrix *fmat(void) const = 0;
|
||||
// virtual destructor
|
||||
virtual ~DMatrix(void){
|
||||
delete fmat;
|
||||
}
|
||||
virtual ~DMatrix(void){}
|
||||
};
|
||||
|
||||
} // namespace learner
|
||||
|
||||
@@ -158,7 +158,7 @@ class BoostLearner {
|
||||
* \param p_train pointer to the matrix used by training
|
||||
*/
|
||||
inline void CheckInit(DMatrix *p_train) {
|
||||
p_train->fmat->InitColAccess(prob_buffer_row);
|
||||
p_train->fmat()->InitColAccess(prob_buffer_row);
|
||||
}
|
||||
/*!
|
||||
* \brief update the model for one iteration
|
||||
@@ -168,7 +168,7 @@ class BoostLearner {
|
||||
inline void UpdateOneIter(int iter, const DMatrix &train) {
|
||||
this->PredictRaw(train, &preds_);
|
||||
obj_->GetGradient(preds_, train.info, iter, &gpair_);
|
||||
gbm_->DoBoost(train.fmat, train.info.info, &gpair_);
|
||||
gbm_->DoBoost(train.fmat(), train.info.info, &gpair_);
|
||||
}
|
||||
/*!
|
||||
* \brief evaluate the model for specific iteration
|
||||
@@ -248,7 +248,7 @@ class BoostLearner {
|
||||
*/
|
||||
inline void PredictRaw(const DMatrix &data,
|
||||
std::vector<float> *out_preds) const {
|
||||
gbm_->Predict(data.fmat, this->FindBufferOffset(data),
|
||||
gbm_->Predict(data.fmat(), this->FindBufferOffset(data),
|
||||
data.info.info, out_preds);
|
||||
// add base margin
|
||||
std::vector<float> &preds = *out_preds;
|
||||
|
||||
Reference in New Issue
Block a user