chg root index to booster info, need review
This commit is contained in:
@@ -28,11 +28,8 @@ struct MetaInfo {
|
||||
std::vector<bst_uint> group_ptr;
|
||||
/*! \brief weights of each instance, optional */
|
||||
std::vector<float> weights;
|
||||
/*!
|
||||
* \brief specified root index of each instance,
|
||||
* can be used for multi task setting
|
||||
*/
|
||||
std::vector<unsigned> root_index;
|
||||
/*! \brief information needed by booster */
|
||||
BoosterInfo info;
|
||||
/*!
|
||||
* \brief initialized margins,
|
||||
* if specified, xgboost will start from this init margin
|
||||
@@ -48,7 +45,7 @@ struct MetaInfo {
|
||||
labels.clear();
|
||||
group_ptr.clear();
|
||||
weights.clear();
|
||||
root_index.clear();
|
||||
info.root_index.clear();
|
||||
base_margin.clear();
|
||||
num_row = num_col = 0;
|
||||
}
|
||||
@@ -60,14 +57,6 @@ struct MetaInfo {
|
||||
return 1.0f;
|
||||
}
|
||||
}
|
||||
/*! \brief get root index of i-th instance */
|
||||
inline float GetRoot(size_t i) const {
|
||||
if (root_index.size() != 0) {
|
||||
return static_cast<float>(root_index[i]);
|
||||
} else {
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
inline void SaveBinary(utils::IStream &fo) const {
|
||||
int version = kVersion;
|
||||
fo.Write(&version, sizeof(version));
|
||||
@@ -76,7 +65,7 @@ struct MetaInfo {
|
||||
fo.Write(labels);
|
||||
fo.Write(group_ptr);
|
||||
fo.Write(weights);
|
||||
fo.Write(root_index);
|
||||
fo.Write(info.root_index);
|
||||
fo.Write(base_margin);
|
||||
}
|
||||
inline void LoadBinary(utils::IStream &fi) {
|
||||
@@ -87,7 +76,7 @@ struct MetaInfo {
|
||||
utils::Check(fi.Read(&labels), "MetaInfo: invalid format");
|
||||
utils::Check(fi.Read(&group_ptr), "MetaInfo: invalid format");
|
||||
utils::Check(fi.Read(&weights), "MetaInfo: invalid format");
|
||||
utils::Check(fi.Read(&root_index), "MetaInfo: invalid format");
|
||||
utils::Check(fi.Read(&info.root_index), "MetaInfo: invalid format");
|
||||
utils::Check(fi.Read(&base_margin), "MetaInfo: invalid format");
|
||||
}
|
||||
// try to load group information from file, if exists
|
||||
|
||||
@@ -161,7 +161,7 @@ class BoostLearner {
|
||||
inline void UpdateOneIter(int iter, const DMatrix<FMatrix> &train) {
|
||||
this->PredictRaw(train, &preds_);
|
||||
obj_->GetGradient(preds_, train.info, iter, &gpair_);
|
||||
gbm_->DoBoost(gpair_, train.fmat, train.info.root_index);
|
||||
gbm_->DoBoost(gpair_, train.fmat, train.info.info);
|
||||
}
|
||||
/*!
|
||||
* \brief evaluate the model for specific iteration
|
||||
@@ -242,7 +242,7 @@ class BoostLearner {
|
||||
inline void PredictRaw(const DMatrix<FMatrix> &data,
|
||||
std::vector<float> *out_preds) const {
|
||||
gbm_->Predict(data.fmat, this->FindBufferOffset(data),
|
||||
data.info.root_index, out_preds);
|
||||
data.info.info, out_preds);
|
||||
// add base margin
|
||||
std::vector<float> &preds = *out_preds;
|
||||
const unsigned ndata = static_cast<unsigned>(preds.size());
|
||||
|
||||
Reference in New Issue
Block a user