check in io module

This commit is contained in:
tqchen
2014-08-16 14:06:31 -07:00
parent ac1cc15b90
commit c4acb4fe01
10 changed files with 417 additions and 33 deletions

View File

@@ -10,10 +10,14 @@
namespace xgboost {
namespace learner {
/*!
/*!
* \brief meta information needed in training, including label, weight
*/
struct MetaInfo {
/*! \brief number of rows in the data */
size_t num_row;
/*! \brief number of columns in the data */
size_t num_col;
/*! \brief label of each instance */
std::vector<float> labels;
/*!
@@ -28,6 +32,15 @@ struct MetaInfo {
* can be used for multi task setting
*/
std::vector<unsigned> root_index;
/*! \brief default constructor, starts with zero rows and zero columns */
MetaInfo(void) : num_row(0), num_col(0) {}
/*! \brief reset the meta information: zero the row/column counts and empty every vector */
inline void Clear(void) {
  num_row = 0;
  num_col = 0;
  root_index.clear();
  weights.clear();
  group_ptr.clear();
  labels.clear();
}
/*! \brief get weight of each instances */
inline float GetWeight(size_t i) const {
if(weights.size() != 0) {
@@ -45,20 +58,53 @@ struct MetaInfo {
}
}
/*!
 * \brief write the meta information to a stream
 * \param fo stream that receives the data
 */
inline void SaveBinary(utils::IStream &fo) {
// the row and column counts go first, as the raw bytes of each variable
fo.Write(&num_row, sizeof(num_row));
fo.Write(&num_col, sizeof(num_col));
// the vectors follow, in the same order LoadBinary reads them back
fo.Write(labels);
fo.Write(group_ptr);
fo.Write(weights);
fo.Write(root_index);
}
/*!
 * \brief read the meta information from a stream
 *
 * Fields are read in the order SaveBinary writes them; the result of
 * each read is handed to utils::Check together with the same message.
 * \param fi stream to read from
 */
inline void LoadBinary(utils::IStream &fi) {
  // single message shared by all the checks below
  const char *const bad_format = "MetaInfo: invalid format";
  // scalar counts first
  utils::Check(fi.Read(&num_row, sizeof(num_row)), bad_format);
  utils::Check(fi.Read(&num_col, sizeof(num_col)), bad_format);
  // then the per-instance vectors
  utils::Check(fi.Read(&labels), bad_format);
  utils::Check(fi.Read(&group_ptr), bad_format);
  utils::Check(fi.Read(&weights), bad_format);
  utils::Check(fi.Read(&root_index), bad_format);
}
/*!
 * \brief try to load group information from a text file, if it exists
 *
 * The file is scanned as a sequence of unsigned integers (group sizes).
 * group_ptr is rebuilt as the running sum of those sizes, starting at 0,
 * so it ends up with one more entry than the number of groups read.
 * \param fname path of the group file
 * \param silent when true, the summary line is not printed to stdout
 * \return false when the file cannot be opened (group_ptr is left untouched),
 *         true otherwise
 */
inline bool TryLoadGroup(const char* fname, bool silent = false) {
  FILE *fi = fopen64(fname, "r");
  if (fi == NULL) return false;
  // start from an empty vector: pushing a fresh 0 after offsets left over
  // from an earlier load would break the cumulative layout
  group_ptr.clear();
  group_ptr.push_back(0);
  unsigned nline;
  while (fscanf(fi, "%u", &nline) == 1) {
    group_ptr.push_back(group_ptr.back() + nline);
  }
  fclose(fi);
  if (!silent) {
    // size_t is not unsigned long on every platform; cast so the argument
    // really matches the %lu conversion
    printf("%lu groups are loaded from %s\n",
           static_cast<unsigned long>(group_ptr.size() - 1), fname);
  }
  return true;
}
/*!
 * \brief try to load weight information from a text file, if it exists
 *
 * The file is scanned as a sequence of floats; each value read becomes
 * one entry of weights, in file order.
 * \param fname path of the weight file
 * \param silent when true, the notice is not printed to stdout
 * \return false when the file cannot be opened (weights is left untouched),
 *         true otherwise
 */
inline bool TryLoadWeight(const char* fname, bool silent = false) {
  FILE *fi = fopen64(fname, "r");
  if (fi == NULL) return false;
  // drop weights from any earlier load; appending to them would shift the
  // position of every weight read from this file
  weights.clear();
  float wt;
  while (fscanf(fi, "%f", &wt) == 1) {
    weights.push_back(wt);
  }
  fclose(fi);
  if (!silent) {
    printf("loading weight from %s\n", fname);
  }
  return true;
}
};
/*!
/*!
* \brief data object used for learning,
* \tparam FMatrix type of feature data source
*/
@@ -66,8 +112,6 @@ template<typename FMatrix>
struct DMatrix {
/*! \brief meta information about the dataset */
MetaInfo info;
/*! \brief number of rows in the DMatrix */
size_t num_row;
/*! \brief feature matrix about data content */
FMatrix fmat;
/*!
@@ -77,6 +121,8 @@ struct DMatrix {
void *cache_learner_ptr_;
/*! \brief default constructor */
DMatrix(void) : cache_learner_ptr_(NULL) {}
// virtual destructor
virtual ~DMatrix(void){}
};
} // namespace learner

View File

@@ -55,9 +55,9 @@ class BoostLearner {
if (dupilicate) continue;
// set mats[i]'s cache learner pointer to this
mats[i]->cache_learner_ptr_ = this;
cache_.push_back(CacheEntry(mats[i], buffer_size, mats[i]->num_row));
buffer_size += mats[i]->num_row;
num_feature = std::max(num_feature, static_cast<unsigned>(mats[i]->num_col));
cache_.push_back(CacheEntry(mats[i], buffer_size, mats[i]->info.num_row));
buffer_size += mats[i]->info.num_row;
num_feature = std::max(num_feature, static_cast<unsigned>(mats[i]->info.num_col));
}
char str_temp[25];
if (num_feature > mparam.num_feature) {