// xgboost/learner/dmatrix.h
// 2014-08-15 20:15:58 -07:00
//
// 85 lines
// 2.3 KiB
// C++

#ifndef XGBOOST_LEARNER_DMATRIX_H_
#define XGBOOST_LEARNER_DMATRIX_H_
/*!
* \file dmatrix.h
* \brief meta data and template data structure
* used for regression/classification/ranking
* \author Tianqi Chen
*/
#include "../data.h"
namespace xgboost {
namespace learner {
/*!
 * \brief meta information needed in training: labels, ranking group
 *  boundaries, optional instance weights and optional root indices
 */
struct MetaInfo {
  /*! \brief label of each instance */
  std::vector<float> labels;
  /*!
   * \brief the index of begin and end of a group
   *  needed when the learning task is ranking
   */
  std::vector<bst_uint> group_ptr;
  /*! \brief weights of each instance, optional */
  std::vector<float> weights;
  /*!
   * \brief specified root index of each instance,
   *  can be used for multi task setting
   */
  std::vector<unsigned> root_index;
  /*!
   * \brief get the weight of the i-th instance
   * \param i index of the instance; the index is not range-checked here
   * \return weights[i] when weights is non-empty, 1.0f otherwise
   */
  inline float GetWeight(size_t i) const {
    return weights.empty() ? 1.0f : weights[i];
  }
  /*!
   * \brief get the root index of the i-th instance
   * \param i index of the instance; the index is not range-checked here
   * \return root_index[i] converted to float when root_index is non-empty,
   *  0 otherwise; the float return type is kept for existing callers, so
   *  very large indices may not be represented exactly
   */
  inline float GetRoot(size_t i) const {
    return root_index.empty() ? 0.0f : static_cast<float>(root_index[i]);
  }
  /*!
   * \brief write all fields to a binary stream
   *  the field order is labels, group_ptr, weights, root_index,
   *  which is the order LoadBinary reads them back in
   * \param fo the output stream
   */
  inline void SaveBinary(utils::IStream &fo) const {
    // const: serialization only reads the members
    fo.Write(labels);
    fo.Write(group_ptr);
    fo.Write(weights);
    fo.Write(root_index);
  }
  /*!
   * \brief read all fields from a binary stream, in the order
   *  labels, group_ptr, weights, root_index
   *  the result of every read is passed to utils::Check together with
   *  the message "MetaInfo: invalid format"
   * \param fi the input stream
   */
  inline void LoadBinary(utils::IStream &fi) {
    utils::Check(fi.Read(&labels), "MetaInfo: invalid format");
    utils::Check(fi.Read(&group_ptr), "MetaInfo: invalid format");
    utils::Check(fi.Read(&weights), "MetaInfo: invalid format");
    utils::Check(fi.Read(&root_index), "MetaInfo: invalid format");
  }
};
/*!
 * \brief data object used for learning: meta information, row count,
 *  the feature matrix itself and a cache marker
 * \tparam FMatrix type of feature data source; it is default-constructed
 *  as a member of this struct
 */
template<typename FMatrix>
struct DMatrix {
  /*! \brief meta information about the dataset */
  MetaInfo info;
  /*! \brief number of rows in the DMatrix */
  size_t num_row;
  /*! \brief feature matrix about data content */
  FMatrix fmat;
  /*!
   * \brief cache pointer to verify if the data structure is cached in some learner
   *  used to verify if DMatrix is cached
   */
  void *cache_learner_ptr_;
  /*!
   * \brief default constructor
   *  starts with zero rows and a NULL cache pointer, so that num_row
   *  never holds an indeterminate value before the data is loaded
   */
  DMatrix(void) : num_row(0), cache_learner_ptr_(NULL) {}
};
} // namespace learner
} // namespace xgboost
#endif // XGBOOST_LEARNER_DMATRIX_H_