Merge pull request #43 from tqchen/unity
add changes that are not committed
This commit is contained in:
commit
4f0b0d2c88
@ -44,6 +44,10 @@ struct bst_gpair {
|
|||||||
* this information is not necessarily present, and can be empty
|
* this information is not necessarily present, and can be empty
|
||||||
*/
|
*/
|
||||||
struct BoosterInfo {
|
struct BoosterInfo {
|
||||||
|
/*! \brief number of rows in the data */
|
||||||
|
size_t num_row;
|
||||||
|
/*! \brief number of columns in the data */
|
||||||
|
size_t num_col;
|
||||||
/*!
|
/*!
|
||||||
* \brief specified root index of each instance,
|
* \brief specified root index of each instance,
|
||||||
* can be used for multi task setting
|
* can be used for multi task setting
|
||||||
@ -51,6 +55,9 @@ struct BoosterInfo {
|
|||||||
std::vector<unsigned> root_index;
|
std::vector<unsigned> root_index;
|
||||||
/*! \brief set fold indicator */
|
/*! \brief set fold indicator */
|
||||||
std::vector<unsigned> fold_index;
|
std::vector<unsigned> fold_index;
|
||||||
|
/*! \brief constructor, sets the number of rows and the number of columns to zero */
|
||||||
|
BoosterInfo(void) : num_row(0), num_col(0) {
|
||||||
|
}
|
||||||
/*! \brief get root of ith instance */
|
/*! \brief get root of ith instance */
|
||||||
inline unsigned GetRoot(size_t i) const {
|
inline unsigned GetRoot(size_t i) const {
|
||||||
return root_index.size() == 0 ? 0 : root_index[i];
|
return root_index.size() == 0 ? 0 : root_index[i];
|
||||||
|
|||||||
@ -135,7 +135,7 @@ class GBTree : public IGradBooster<FMatrix> {
|
|||||||
const int tid = omp_get_thread_num();
|
const int tid = omp_get_thread_num();
|
||||||
tree::RegTree::FVec &feats = thread_temp[tid];
|
tree::RegTree::FVec &feats = thread_temp[tid];
|
||||||
const size_t ridx = batch.base_rowid + i;
|
const size_t ridx = batch.base_rowid + i;
|
||||||
const unsigned root_idx = info.GetRoot(i);
|
const unsigned root_idx = info.GetRoot(ridx);
|
||||||
// loop over output groups
|
// loop over output groups
|
||||||
for (int gid = 0; gid < mparam.num_output_group; ++gid) {
|
for (int gid = 0; gid < mparam.num_output_group; ++gid) {
|
||||||
preds[ridx * mparam.num_output_group + gid] =
|
preds[ridx * mparam.num_output_group + gid] =
|
||||||
|
|||||||
@ -62,10 +62,10 @@ class DMatrixSimple : public DataMatrix {
|
|||||||
inline size_t AddRow(const std::vector<SparseBatch::Entry> &feats) {
|
inline size_t AddRow(const std::vector<SparseBatch::Entry> &feats) {
|
||||||
for (size_t i = 0; i < feats.size(); ++i) {
|
for (size_t i = 0; i < feats.size(); ++i) {
|
||||||
row_data_.push_back(feats[i]);
|
row_data_.push_back(feats[i]);
|
||||||
info.num_col = std::max(info.num_col, static_cast<size_t>(feats[i].findex+1));
|
info.info.num_col = std::max(info.info.num_col, static_cast<size_t>(feats[i].findex+1));
|
||||||
}
|
}
|
||||||
row_ptr_.push_back(row_ptr_.back() + feats.size());
|
row_ptr_.push_back(row_ptr_.back() + feats.size());
|
||||||
info.num_row += 1;
|
info.info.num_row += 1;
|
||||||
return row_ptr_.size() - 2;
|
return row_ptr_.size() - 2;
|
||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
@ -99,19 +99,19 @@ class DMatrixSimple : public DataMatrix {
|
|||||||
|
|
||||||
if (!silent) {
|
if (!silent) {
|
||||||
printf("%lux%lu matrix with %lu entries is loaded from %s\n",
|
printf("%lux%lu matrix with %lu entries is loaded from %s\n",
|
||||||
info.num_row, info.num_col, row_data_.size(), fname);
|
info.num_row(), info.num_col(), row_data_.size(), fname);
|
||||||
}
|
}
|
||||||
fclose(file);
|
fclose(file);
|
||||||
// try to load in additional file
|
// try to load in additional file
|
||||||
std::string name = fname;
|
std::string name = fname;
|
||||||
std::string gname = name + ".group";
|
std::string gname = name + ".group";
|
||||||
if (info.TryLoadGroup(gname.c_str(), silent)) {
|
if (info.TryLoadGroup(gname.c_str(), silent)) {
|
||||||
utils::Check(info.group_ptr.back() == info.num_row,
|
utils::Check(info.group_ptr.back() == info.num_row(),
|
||||||
"DMatrix: group data does not match the number of rows in features");
|
"DMatrix: group data does not match the number of rows in features");
|
||||||
}
|
}
|
||||||
std::string wname = name + ".weight";
|
std::string wname = name + ".weight";
|
||||||
if (info.TryLoadFloatInfo("weight", wname.c_str(), silent)) {
|
if (info.TryLoadFloatInfo("weight", wname.c_str(), silent)) {
|
||||||
utils::Check(info.weights.size() == info.num_row,
|
utils::Check(info.weights.size() == info.num_row(),
|
||||||
"DMatrix: weight data does not match the number of rows in features");
|
"DMatrix: weight data does not match the number of rows in features");
|
||||||
}
|
}
|
||||||
std::string mname = name + ".base_margin";
|
std::string mname = name + ".base_margin";
|
||||||
@ -139,7 +139,7 @@ class DMatrixSimple : public DataMatrix {
|
|||||||
|
|
||||||
if (!silent) {
|
if (!silent) {
|
||||||
printf("%lux%lu matrix with %lu entries is loaded from %s\n",
|
printf("%lux%lu matrix with %lu entries is loaded from %s\n",
|
||||||
info.num_row, info.num_col, row_data_.size(), fname);
|
info.num_row(), info.num_col(), row_data_.size(), fname);
|
||||||
if (info.group_ptr.size() != 0) {
|
if (info.group_ptr.size() != 0) {
|
||||||
printf("data contains %u groups\n", (unsigned)info.group_ptr.size()-1);
|
printf("data contains %u groups\n", (unsigned)info.group_ptr.size()-1);
|
||||||
}
|
}
|
||||||
@ -163,7 +163,7 @@ class DMatrixSimple : public DataMatrix {
|
|||||||
|
|
||||||
if (!silent) {
|
if (!silent) {
|
||||||
printf("%lux%lu matrix with %lu entries is saved to %s\n",
|
printf("%lux%lu matrix with %lu entries is saved to %s\n",
|
||||||
info.num_row, info.num_col, row_data_.size(), fname);
|
info.num_row(), info.num_col(), row_data_.size(), fname);
|
||||||
if (info.group_ptr.size() != 0) {
|
if (info.group_ptr.size() != 0) {
|
||||||
printf("data contains %lu groups\n", info.group_ptr.size()-1);
|
printf("data contains %lu groups\n", info.group_ptr.size()-1);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -15,10 +15,12 @@ namespace learner {
|
|||||||
* \brief meta information needed in training, including label, weight
|
* \brief meta information needed in training, including label, weight
|
||||||
*/
|
*/
|
||||||
struct MetaInfo {
|
struct MetaInfo {
|
||||||
/*! \brief number of rows in the data */
|
/*!
|
||||||
size_t num_row;
|
* \brief information needed by booster
|
||||||
/*! \brief number of columns in the data */
|
* BoosterInfo does not implement save and load,
|
||||||
size_t num_col;
|
* all serialization is done in MetaInfo
|
||||||
|
*/
|
||||||
|
BoosterInfo info;
|
||||||
/*! \brief label of each instance */
|
/*! \brief label of each instance */
|
||||||
std::vector<float> labels;
|
std::vector<float> labels;
|
||||||
/*!
|
/*!
|
||||||
@ -28,8 +30,6 @@ struct MetaInfo {
|
|||||||
std::vector<bst_uint> group_ptr;
|
std::vector<bst_uint> group_ptr;
|
||||||
/*! \brief weights of each instance, optional */
|
/*! \brief weights of each instance, optional */
|
||||||
std::vector<float> weights;
|
std::vector<float> weights;
|
||||||
/*! \brief information needed by booster */
|
|
||||||
BoosterInfo info;
|
|
||||||
/*!
|
/*!
|
||||||
* \brief initialized margins,
|
* \brief initialized margins,
|
||||||
* if specified, xgboost will start from this init margin
|
* if specified, xgboost will start from this init margin
|
||||||
@ -39,7 +39,15 @@ struct MetaInfo {
|
|||||||
/*! \brief version flag, used to check version of this info */
|
/*! \brief version flag, used to check version of this info */
|
||||||
static const int kVersion = 0;
|
static const int kVersion = 0;
|
||||||
// constructor
|
// constructor
|
||||||
MetaInfo(void) : num_row(0), num_col(0) {}
|
MetaInfo(void) {}
|
||||||
|
/*! \return number of rows in dataset */
|
||||||
|
inline size_t num_row(void) const {
|
||||||
|
return info.num_row;
|
||||||
|
}
|
||||||
|
/*! \return number of columns in dataset */
|
||||||
|
inline size_t num_col(void) const {
|
||||||
|
return info.num_col;
|
||||||
|
}
|
||||||
/*! \brief clear all the information */
|
/*! \brief clear all the information */
|
||||||
inline void Clear(void) {
|
inline void Clear(void) {
|
||||||
labels.clear();
|
labels.clear();
|
||||||
@ -47,7 +55,7 @@ struct MetaInfo {
|
|||||||
weights.clear();
|
weights.clear();
|
||||||
info.root_index.clear();
|
info.root_index.clear();
|
||||||
base_margin.clear();
|
base_margin.clear();
|
||||||
num_row = num_col = 0;
|
info.num_row = info.num_col = 0;
|
||||||
}
|
}
|
||||||
/*! \brief get weight of each instance */
|
/*! \brief get weight of each instance */
|
||||||
inline float GetWeight(size_t i) const {
|
inline float GetWeight(size_t i) const {
|
||||||
@ -60,8 +68,8 @@ struct MetaInfo {
|
|||||||
inline void SaveBinary(utils::IStream &fo) const {
|
inline void SaveBinary(utils::IStream &fo) const {
|
||||||
int version = kVersion;
|
int version = kVersion;
|
||||||
fo.Write(&version, sizeof(version));
|
fo.Write(&version, sizeof(version));
|
||||||
fo.Write(&num_row, sizeof(num_row));
|
fo.Write(&info.num_row, sizeof(info.num_row));
|
||||||
fo.Write(&num_col, sizeof(num_col));
|
fo.Write(&info.num_col, sizeof(info.num_col));
|
||||||
fo.Write(labels);
|
fo.Write(labels);
|
||||||
fo.Write(group_ptr);
|
fo.Write(group_ptr);
|
||||||
fo.Write(weights);
|
fo.Write(weights);
|
||||||
@ -71,8 +79,8 @@ struct MetaInfo {
|
|||||||
inline void LoadBinary(utils::IStream &fi) {
|
inline void LoadBinary(utils::IStream &fi) {
|
||||||
int version;
|
int version;
|
||||||
utils::Check(fi.Read(&version, sizeof(version)), "MetaInfo: invalid format");
|
utils::Check(fi.Read(&version, sizeof(version)), "MetaInfo: invalid format");
|
||||||
utils::Check(fi.Read(&num_row, sizeof(num_row)), "MetaInfo: invalid format");
|
utils::Check(fi.Read(&info.num_row, sizeof(info.num_row)), "MetaInfo: invalid format");
|
||||||
utils::Check(fi.Read(&num_col, sizeof(num_col)), "MetaInfo: invalid format");
|
utils::Check(fi.Read(&info.num_col, sizeof(info.num_col)), "MetaInfo: invalid format");
|
||||||
utils::Check(fi.Read(&labels), "MetaInfo: invalid format");
|
utils::Check(fi.Read(&labels), "MetaInfo: invalid format");
|
||||||
utils::Check(fi.Read(&group_ptr), "MetaInfo: invalid format");
|
utils::Check(fi.Read(&group_ptr), "MetaInfo: invalid format");
|
||||||
utils::Check(fi.Read(&weights), "MetaInfo: invalid format");
|
utils::Check(fi.Read(&weights), "MetaInfo: invalid format");
|
||||||
@ -94,19 +102,28 @@ struct MetaInfo {
|
|||||||
fclose(fi);
|
fclose(fi);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
inline std::vector<float>& GetInfo(const char *field) {
|
inline std::vector<float>& GetFloatInfo(const char *field) {
|
||||||
if (!strcmp(field, "label")) return labels;
|
if (!strcmp(field, "label")) return labels;
|
||||||
if (!strcmp(field, "weight")) return weights;
|
if (!strcmp(field, "weight")) return weights;
|
||||||
if (!strcmp(field, "base_margin")) return base_margin;
|
if (!strcmp(field, "base_margin")) return base_margin;
|
||||||
utils::Error("unknown field %s", field);
|
utils::Error("unknown field %s", field);
|
||||||
return labels;
|
return labels;
|
||||||
}
|
}
|
||||||
inline const std::vector<float>& GetInfo(const char *field) const {
|
inline const std::vector<float>& GetFloatInfo(const char *field) const {
|
||||||
return ((MetaInfo*)this)->GetInfo(field);
|
return ((MetaInfo*)this)->GetFloatInfo(field);
|
||||||
|
}
|
||||||
|
inline std::vector<unsigned> &GetUIntInfo(const char *field) {
|
||||||
|
if (!strcmp(field, "root_index")) return info.root_index;
|
||||||
|
if (!strcmp(field, "fold_index")) return info.fold_index;
|
||||||
|
utils::Error("unknown field %s", field);
|
||||||
|
return info.root_index;
|
||||||
|
}
|
||||||
|
inline const std::vector<unsigned> &GetUIntInfo(const char *field) const {
|
||||||
|
return ((MetaInfo*)this)->GetUIntInfo(field);
|
||||||
}
|
}
|
||||||
// try to load the float information of the given field from a file, if the file exists
|
// try to load the float information of the given field from a file, if the file exists
|
||||||
inline bool TryLoadFloatInfo(const char *field, const char* fname, bool silent = false) {
|
inline bool TryLoadFloatInfo(const char *field, const char* fname, bool silent = false) {
|
||||||
std::vector<float> &weights = this->GetInfo(field);
|
std::vector<float> &weights = this->GetFloatInfo(field);
|
||||||
FILE *fi = fopen64(fname, "r");
|
FILE *fi = fopen64(fname, "r");
|
||||||
if (fi == NULL) return false;
|
if (fi == NULL) return false;
|
||||||
float wt;
|
float wt;
|
||||||
|
|||||||
@ -58,9 +58,9 @@ class BoostLearner {
|
|||||||
if (dupilicate) continue;
|
if (dupilicate) continue;
|
||||||
// set mats[i]'s cache learner pointer to this
|
// set mats[i]'s cache learner pointer to this
|
||||||
mats[i]->cache_learner_ptr_ = this;
|
mats[i]->cache_learner_ptr_ = this;
|
||||||
cache_.push_back(CacheEntry(mats[i], buffer_size, mats[i]->info.num_row));
|
cache_.push_back(CacheEntry(mats[i], buffer_size, mats[i]->info.num_row()));
|
||||||
buffer_size += mats[i]->info.num_row;
|
buffer_size += mats[i]->info.num_row();
|
||||||
num_feature = std::max(num_feature, static_cast<unsigned>(mats[i]->info.num_col));
|
num_feature = std::max(num_feature, static_cast<unsigned>(mats[i]->info.num_col()));
|
||||||
}
|
}
|
||||||
char str_temp[25];
|
char str_temp[25];
|
||||||
if (num_feature > mparam.num_feature) {
|
if (num_feature > mparam.num_feature) {
|
||||||
@ -329,7 +329,7 @@ class BoostLearner {
|
|||||||
inline int64_t FindBufferOffset(const DMatrix<FMatrix> &mat) const {
|
inline int64_t FindBufferOffset(const DMatrix<FMatrix> &mat) const {
|
||||||
for (size_t i = 0; i < cache_.size(); ++i) {
|
for (size_t i = 0; i < cache_.size(); ++i) {
|
||||||
if (cache_[i].mat_ == &mat && mat.cache_learner_ptr_ == this) {
|
if (cache_[i].mat_ == &mat && mat.cache_learner_ptr_ == this) {
|
||||||
if (cache_[i].num_row_ == mat.info.num_row) {
|
if (cache_[i].num_row_ == mat.info.num_row()) {
|
||||||
return cache_[i].buffer_offset_;
|
return cache_[i].buffer_offset_;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user