chg root index to booster info, need review

This commit is contained in:
tqchen
2014-08-22 16:26:37 -07:00
parent 58d74861b9
commit 58354643b0
10 changed files with 48 additions and 44 deletions

View File

@@ -28,11 +28,8 @@ struct MetaInfo {
std::vector<bst_uint> group_ptr;
/*! \brief weights of each instance, optional */
std::vector<float> weights;
/*!
* \brief specified root index of each instance,
* can be used for multi task setting
*/
std::vector<unsigned> root_index;
/*! \brief information needed by booster */
BoosterInfo info;
/*!
* \brief initialized margins,
* if specified, xgboost will start from this init margin
@@ -48,7 +45,7 @@ struct MetaInfo {
labels.clear();
group_ptr.clear();
weights.clear();
root_index.clear();
info.root_index.clear();
base_margin.clear();
num_row = num_col = 0;
}
@@ -60,14 +57,6 @@ struct MetaInfo {
return 1.0f;
}
}
/*! \brief get root index of i-th instance */
inline float GetRoot(size_t i) const {
if (root_index.size() != 0) {
return static_cast<float>(root_index[i]);
} else {
return 0;
}
}
inline void SaveBinary(utils::IStream &fo) const {
int version = kVersion;
fo.Write(&version, sizeof(version));
@@ -76,7 +65,7 @@ struct MetaInfo {
fo.Write(labels);
fo.Write(group_ptr);
fo.Write(weights);
fo.Write(root_index);
fo.Write(info.root_index);
fo.Write(base_margin);
}
inline void LoadBinary(utils::IStream &fi) {
@@ -87,7 +76,7 @@ struct MetaInfo {
utils::Check(fi.Read(&labels), "MetaInfo: invalid format");
utils::Check(fi.Read(&group_ptr), "MetaInfo: invalid format");
utils::Check(fi.Read(&weights), "MetaInfo: invalid format");
utils::Check(fi.Read(&root_index), "MetaInfo: invalid format");
utils::Check(fi.Read(&info.root_index), "MetaInfo: invalid format");
utils::Check(fi.Read(&base_margin), "MetaInfo: invalid format");
}
// try to load group information from file, if exists

View File

@@ -161,7 +161,7 @@ class BoostLearner {
inline void UpdateOneIter(int iter, const DMatrix<FMatrix> &train) {
this->PredictRaw(train, &preds_);
obj_->GetGradient(preds_, train.info, iter, &gpair_);
gbm_->DoBoost(gpair_, train.fmat, train.info.root_index);
gbm_->DoBoost(gpair_, train.fmat, train.info.info);
}
/*!
* \brief evaluate the model for specific iteration
@@ -242,7 +242,7 @@ class BoostLearner {
inline void PredictRaw(const DMatrix<FMatrix> &data,
std::vector<float> *out_preds) const {
gbm_->Predict(data.fmat, this->FindBufferOffset(data),
data.info.root_index, out_preds);
data.info.info, out_preds);
// add base margin
std::vector<float> &preds = *out_preds;
const unsigned ndata = static_cast<unsigned>(preds.size());