check in io module
This commit is contained in:
@@ -10,10 +10,14 @@
|
||||
|
||||
namespace xgboost {
|
||||
namespace learner {
|
||||
/*!
|
||||
/*!
|
||||
* \brief meta information needed in training, including label, weight
|
||||
*/
|
||||
struct MetaInfo {
|
||||
/*! \brief number of rows in the data */
|
||||
size_t num_row;
|
||||
/*! \brief number of columns in the data */
|
||||
size_t num_col;
|
||||
/*! \brief label of each instance */
|
||||
std::vector<float> labels;
|
||||
/*!
|
||||
@@ -28,6 +32,15 @@ struct MetaInfo {
|
||||
* can be used for multi task setting
|
||||
*/
|
||||
std::vector<unsigned> root_index;
|
||||
MetaInfo(void) : num_row(0), num_col(0) {}
|
||||
/*! \brief clear all the information */
|
||||
inline void Clear(void) {
|
||||
labels.clear();
|
||||
group_ptr.clear();
|
||||
weights.clear();
|
||||
root_index.clear();
|
||||
num_row = num_col = 0;
|
||||
}
|
||||
/*! \brief get weight of each instances */
|
||||
inline float GetWeight(size_t i) const {
|
||||
if(weights.size() != 0) {
|
||||
@@ -45,20 +58,53 @@ struct MetaInfo {
|
||||
}
|
||||
}
|
||||
inline void SaveBinary(utils::IStream &fo) {
|
||||
fo.Write(&num_row, sizeof(num_row));
|
||||
fo.Write(&num_col, sizeof(num_col));
|
||||
fo.Write(labels);
|
||||
fo.Write(group_ptr);
|
||||
fo.Write(weights);
|
||||
fo.Write(root_index);
|
||||
}
|
||||
inline void LoadBinary(utils::IStream &fi) {
|
||||
utils::Check(fi.Read(&num_row, sizeof(num_row)), "MetaInfo: invalid format");
|
||||
utils::Check(fi.Read(&num_col, sizeof(num_col)), "MetaInfo: invalid format");
|
||||
utils::Check(fi.Read(&labels), "MetaInfo: invalid format");
|
||||
utils::Check(fi.Read(&group_ptr), "MetaInfo: invalid format");
|
||||
utils::Check(fi.Read(&weights), "MetaInfo: invalid format");
|
||||
utils::Check(fi.Read(&root_index), "MetaInfo: invalid format");
|
||||
}
|
||||
// try to load group information from file, if exists
|
||||
inline bool TryLoadGroup(const char* fname, bool silent = false) {
|
||||
FILE *fi = fopen64(fname, "r");
|
||||
if (fi == NULL) return false;
|
||||
group_ptr.push_back(0);
|
||||
unsigned nline;
|
||||
while (fscanf(fi, "%u", &nline) == 1) {
|
||||
group_ptr.push_back(group_ptr.back()+nline);
|
||||
}
|
||||
if (!silent) {
|
||||
printf("%lu groups are loaded from %s\n", group_ptr.size()-1, fname);
|
||||
}
|
||||
fclose(fi);
|
||||
return true;
|
||||
}
|
||||
// try to load weight information from file, if exists
|
||||
inline bool TryLoadWeight(const char* fname, bool silent = false) {
|
||||
FILE *fi = fopen64(fname, "r");
|
||||
if (fi == NULL) return false;
|
||||
float wt;
|
||||
while (fscanf(fi, "%f", &wt) == 1) {
|
||||
weights.push_back(wt);
|
||||
}
|
||||
if (!silent) {
|
||||
printf("loading weight from %s\n", fname);
|
||||
}
|
||||
fclose(fi);
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
/*!
|
||||
/*!
|
||||
* \brief data object used for learning,
|
||||
* \tparam FMatrix type of feature data source
|
||||
*/
|
||||
@@ -66,8 +112,6 @@ template<typename FMatrix>
|
||||
struct DMatrix {
|
||||
/*! \brief meta information about the dataset */
|
||||
MetaInfo info;
|
||||
/*! \brief number of rows in the DMatrix */
|
||||
size_t num_row;
|
||||
/*! \brief feature matrix about data content */
|
||||
FMatrix fmat;
|
||||
/*!
|
||||
@@ -77,6 +121,8 @@ struct DMatrix {
|
||||
void *cache_learner_ptr_;
|
||||
/*! \brief default constructor */
|
||||
DMatrix(void) : cache_learner_ptr_(NULL) {}
|
||||
// virtual destructor
|
||||
virtual ~DMatrix(void){}
|
||||
};
|
||||
|
||||
} // namespace learner
|
||||
|
||||
@@ -55,9 +55,9 @@ class BoostLearner {
|
||||
if (dupilicate) continue;
|
||||
// set mats[i]'s cache learner pointer to this
|
||||
mats[i]->cache_learner_ptr_ = this;
|
||||
cache_.push_back(CacheEntry(mats[i], buffer_size, mats[i]->num_row));
|
||||
buffer_size += mats[i]->num_row;
|
||||
num_feature = std::max(num_feature, static_cast<unsigned>(mats[i]->num_col));
|
||||
cache_.push_back(CacheEntry(mats[i], buffer_size, mats[i]->info.num_row));
|
||||
buffer_size += mats[i]->info.num_row;
|
||||
num_feature = std::max(num_feature, static_cast<unsigned>(mats[i]->info.num_col));
|
||||
}
|
||||
char str_temp[25];
|
||||
if (num_feature > mparam.num_feature) {
|
||||
|
||||
Reference in New Issue
Block a user