Change root_index into a BoosterInfo struct passed to gbm/updater interfaces; needs review

This commit is contained in:
tqchen 2014-08-22 16:26:37 -07:00
parent 58d74861b9
commit 58354643b0
10 changed files with 48 additions and 44 deletions

View File

@ -37,7 +37,7 @@ class Booster: public learner::BoostLearner<FMatrixS> {
for (unsigned j = 0; j < ndata; ++j) { for (unsigned j = 0; j < ndata; ++j) {
gpair_[j] = bst_gpair(grad[j], hess[j]); gpair_[j] = bst_gpair(grad[j], hess[j]);
} }
gbm_->DoBoost(gpair_, train.fmat, train.info.root_index); gbm_->DoBoost(gpair_, train.fmat, train.info.info);
} }
inline void CheckInitModel(void) { inline void CheckInitModel(void) {
if (!init_model) { if (!init_model) {
@ -151,8 +151,8 @@ extern "C"{
if (src.info.weights.size() != 0) { if (src.info.weights.size() != 0) {
ret.info.weights.push_back(src.info.weights[ridx]); ret.info.weights.push_back(src.info.weights[ridx]);
} }
if (src.info.root_index.size() != 0) { if (src.info.info.root_index.size() != 0) {
ret.info.weights.push_back(src.info.root_index[ridx]); ret.info.info.root_index.push_back(src.info.info.root_index[ridx]);
} }
} }
return p_ret; return p_ret;

View File

@ -39,6 +39,24 @@ struct bst_gpair {
bst_gpair(bst_float grad, bst_float hess) : grad(grad), hess(hess) {} bst_gpair(bst_float grad, bst_float hess) : grad(grad), hess(hess) {}
}; };
/*!
 * \brief extra information that may be needed by the gbm and tree modules;
 * all fields are optional and may be left empty
 */
struct BoosterInfo {
  /*!
   * \brief specified root index of each instance,
   * can be used for multi task setting
   */
  std::vector<unsigned> root_index;
  /*! \brief set fold indicator */
  std::vector<unsigned> fold_index;
  /*!
   * \brief get the root of the i-th instance
   * \param i instance index
   * \return the specified root for instance i, or 0 when no roots were given
   */
  inline unsigned GetRoot(size_t i) const {
    if (root_index.empty()) return 0;
    return root_index[i];
  }
};
/*! \brief read-only sparse instance batch in CSR format */ /*! \brief read-only sparse instance batch in CSR format */
struct SparseBatch { struct SparseBatch {
/*! \brief an entry of sparse vector */ /*! \brief an entry of sparse vector */

View File

@ -43,12 +43,11 @@ class IGradBooster {
* \brief peform update to the model(boosting) * \brief peform update to the model(boosting)
* \param gpair the gradient pair statistics of the data * \param gpair the gradient pair statistics of the data
* \param fmat feature matrix that provide access to features * \param fmat feature matrix that provide access to features
* \param root_index pre-partitioned root_index of each instance, * \param info meta information about training
* root_index.size() can be 0 which indicates that no pre-partition involved
*/ */
virtual void DoBoost(const std::vector<bst_gpair> &gpair, virtual void DoBoost(const std::vector<bst_gpair> &gpair,
const FMatrix &fmat, const FMatrix &fmat,
const std::vector<unsigned> &root_index) = 0; const BoosterInfo &info) = 0;
/*! /*!
* \brief generate predictions for given feature matrix * \brief generate predictions for given feature matrix
* \param fmat feature matrix * \param fmat feature matrix
@ -56,13 +55,12 @@ class IGradBooster {
* this means we do not have buffer index allocated to the gbm * this means we do not have buffer index allocated to the gbm
* a buffer index is assigned to each instance that requires repeative prediction * a buffer index is assigned to each instance that requires repeative prediction
* the size of buffer is set by convention using IGradBooster.SetParam("num_pbuffer","size") * the size of buffer is set by convention using IGradBooster.SetParam("num_pbuffer","size")
* \param root_index pre-partitioned root_index of each instance, * \param info extra side information that may be needed for prediction
* root_index.size() can be 0 which indicates that no pre-partition involved
* \param out_preds output vector to hold the predictions * \param out_preds output vector to hold the predictions
*/ */
virtual void Predict(const FMatrix &fmat, virtual void Predict(const FMatrix &fmat,
int64_t buffer_offset, int64_t buffer_offset,
const std::vector<unsigned> &root_index, const BoosterInfo &info,
std::vector<float> *out_preds) = 0; std::vector<float> *out_preds) = 0;
/*! /*!
* \brief dump the model in text format * \brief dump the model in text format

View File

@ -84,9 +84,9 @@ class GBTree : public IGradBooster<FMatrix> {
} }
virtual void DoBoost(const std::vector<bst_gpair> &gpair, virtual void DoBoost(const std::vector<bst_gpair> &gpair,
const FMatrix &fmat, const FMatrix &fmat,
const std::vector<unsigned> &root_index) { const BoosterInfo &info) {
if (mparam.num_output_group == 1) { if (mparam.num_output_group == 1) {
this->BoostNewTrees(gpair, fmat, root_index, 0); this->BoostNewTrees(gpair, fmat, info, 0);
} else { } else {
const int ngroup = mparam.num_output_group; const int ngroup = mparam.num_output_group;
utils::Check(gpair.size() % ngroup == 0, utils::Check(gpair.size() % ngroup == 0,
@ -97,13 +97,13 @@ class GBTree : public IGradBooster<FMatrix> {
for (size_t i = 0; i < tmp.size(); ++i) { for (size_t i = 0; i < tmp.size(); ++i) {
tmp[i] = gpair[i * ngroup + gid]; tmp[i] = gpair[i * ngroup + gid];
} }
this->BoostNewTrees(tmp, fmat, root_index, gid); this->BoostNewTrees(tmp, fmat, info, gid);
} }
} }
} }
virtual void Predict(const FMatrix &fmat, virtual void Predict(const FMatrix &fmat,
int64_t buffer_offset, int64_t buffer_offset,
const std::vector<unsigned> &root_index, const BoosterInfo &info,
std::vector<float> *out_preds) { std::vector<float> *out_preds) {
int nthread; int nthread;
#pragma omp parallel #pragma omp parallel
@ -134,7 +134,7 @@ class GBTree : public IGradBooster<FMatrix> {
const int tid = omp_get_thread_num(); const int tid = omp_get_thread_num();
tree::RegTree::FVec &feats = thread_temp[tid]; tree::RegTree::FVec &feats = thread_temp[tid];
const size_t ridx = batch.base_rowid + i; const size_t ridx = batch.base_rowid + i;
const unsigned root_idx = root_index.size() == 0 ? 0 : root_index[ridx]; const unsigned root_idx = info.GetRoot(i);
// loop over output groups // loop over output groups
for (int gid = 0; gid < mparam.num_output_group; ++gid) { for (int gid = 0; gid < mparam.num_output_group; ++gid) {
preds[ridx * mparam.num_output_group + gid] = preds[ridx * mparam.num_output_group + gid] =
@ -186,7 +186,7 @@ class GBTree : public IGradBooster<FMatrix> {
// do group specific group // do group specific group
inline void BoostNewTrees(const std::vector<bst_gpair> &gpair, inline void BoostNewTrees(const std::vector<bst_gpair> &gpair,
const FMatrix &fmat, const FMatrix &fmat,
const std::vector<unsigned> &root_index, const BoosterInfo &info,
int bst_group) { int bst_group) {
this->InitUpdater(); this->InitUpdater();
// create the trees // create the trees
@ -200,7 +200,7 @@ class GBTree : public IGradBooster<FMatrix> {
} }
// update the trees // update the trees
for (size_t i = 0; i < updaters.size(); ++i) { for (size_t i = 0; i < updaters.size(); ++i) {
updaters[i]->Update(gpair, fmat, root_index, new_trees); updaters[i]->Update(gpair, fmat, info, new_trees);
} }
// push back to model // push back to model
for (size_t i = 0; i < new_trees.size(); ++i) { for (size_t i = 0; i < new_trees.size(); ++i) {

View File

@ -28,11 +28,8 @@ struct MetaInfo {
std::vector<bst_uint> group_ptr; std::vector<bst_uint> group_ptr;
/*! \brief weights of each instance, optional */ /*! \brief weights of each instance, optional */
std::vector<float> weights; std::vector<float> weights;
/*! /*! \brief information needed by booster */
* \brief specified root index of each instance, BoosterInfo info;
* can be used for multi task setting
*/
std::vector<unsigned> root_index;
/*! /*!
* \brief initialized margins, * \brief initialized margins,
* if specified, xgboost will start from this init margin * if specified, xgboost will start from this init margin
@ -48,7 +45,7 @@ struct MetaInfo {
labels.clear(); labels.clear();
group_ptr.clear(); group_ptr.clear();
weights.clear(); weights.clear();
root_index.clear(); info.root_index.clear();
base_margin.clear(); base_margin.clear();
num_row = num_col = 0; num_row = num_col = 0;
} }
@ -60,14 +57,6 @@ struct MetaInfo {
return 1.0f; return 1.0f;
} }
} }
/*! \brief get root index of i-th instance */
inline float GetRoot(size_t i) const {
if (root_index.size() != 0) {
return static_cast<float>(root_index[i]);
} else {
return 0;
}
}
inline void SaveBinary(utils::IStream &fo) const { inline void SaveBinary(utils::IStream &fo) const {
int version = kVersion; int version = kVersion;
fo.Write(&version, sizeof(version)); fo.Write(&version, sizeof(version));
@ -76,7 +65,7 @@ struct MetaInfo {
fo.Write(labels); fo.Write(labels);
fo.Write(group_ptr); fo.Write(group_ptr);
fo.Write(weights); fo.Write(weights);
fo.Write(root_index); fo.Write(info.root_index);
fo.Write(base_margin); fo.Write(base_margin);
} }
inline void LoadBinary(utils::IStream &fi) { inline void LoadBinary(utils::IStream &fi) {
@ -87,7 +76,7 @@ struct MetaInfo {
utils::Check(fi.Read(&labels), "MetaInfo: invalid format"); utils::Check(fi.Read(&labels), "MetaInfo: invalid format");
utils::Check(fi.Read(&group_ptr), "MetaInfo: invalid format"); utils::Check(fi.Read(&group_ptr), "MetaInfo: invalid format");
utils::Check(fi.Read(&weights), "MetaInfo: invalid format"); utils::Check(fi.Read(&weights), "MetaInfo: invalid format");
utils::Check(fi.Read(&root_index), "MetaInfo: invalid format"); utils::Check(fi.Read(&info.root_index), "MetaInfo: invalid format");
utils::Check(fi.Read(&base_margin), "MetaInfo: invalid format"); utils::Check(fi.Read(&base_margin), "MetaInfo: invalid format");
} }
// try to load group information from file, if exists // try to load group information from file, if exists

View File

@ -161,7 +161,7 @@ class BoostLearner {
inline void UpdateOneIter(int iter, const DMatrix<FMatrix> &train) { inline void UpdateOneIter(int iter, const DMatrix<FMatrix> &train) {
this->PredictRaw(train, &preds_); this->PredictRaw(train, &preds_);
obj_->GetGradient(preds_, train.info, iter, &gpair_); obj_->GetGradient(preds_, train.info, iter, &gpair_);
gbm_->DoBoost(gpair_, train.fmat, train.info.root_index); gbm_->DoBoost(gpair_, train.fmat, train.info.info);
} }
/*! /*!
* \brief evaluate the model for specific iteration * \brief evaluate the model for specific iteration
@ -242,7 +242,7 @@ class BoostLearner {
inline void PredictRaw(const DMatrix<FMatrix> &data, inline void PredictRaw(const DMatrix<FMatrix> &data,
std::vector<float> *out_preds) const { std::vector<float> *out_preds) const {
gbm_->Predict(data.fmat, this->FindBufferOffset(data), gbm_->Predict(data.fmat, this->FindBufferOffset(data),
data.info.root_index, out_preds); data.info.info, out_preds);
// add base margin // add base margin
std::vector<float> &preds = *out_preds; std::vector<float> &preds = *out_preds;
const unsigned ndata = static_cast<unsigned>(preds.size()); const unsigned ndata = static_cast<unsigned>(preds.size());

View File

@ -29,8 +29,7 @@ class IUpdater {
* \brief peform update to the tree models * \brief peform update to the tree models
* \param gpair the gradient pair statistics of the data * \param gpair the gradient pair statistics of the data
* \param fmat feature matrix that provide access to features * \param fmat feature matrix that provide access to features
* \param root_index pre-partitioned root_index of each instance, * \param info extra side information that may be need, such as root index
* root_index.size() can be 0 which indicates that no pre-partition involved
* \param trees pointer to the trese to be updated, upater will change the content of the tree * \param trees pointer to the trese to be updated, upater will change the content of the tree
* note: all the trees in the vector are updated, with the same statistics, * note: all the trees in the vector are updated, with the same statistics,
* but maybe different random seeds, usually one tree is passed in at a time, * but maybe different random seeds, usually one tree is passed in at a time,
@ -38,7 +37,7 @@ class IUpdater {
*/ */
virtual void Update(const std::vector<bst_gpair> &gpair, virtual void Update(const std::vector<bst_gpair> &gpair,
const FMatrix &fmat, const FMatrix &fmat,
const std::vector<unsigned> &root_index, const BoosterInfo &info,
const std::vector<RegTree*> &trees) = 0; const std::vector<RegTree*> &trees) = 0;
// destructor // destructor
virtual ~IUpdater(void) {} virtual ~IUpdater(void) {}

View File

@ -25,7 +25,7 @@ class ColMaker: public IUpdater<FMatrix> {
} }
virtual void Update(const std::vector<bst_gpair> &gpair, virtual void Update(const std::vector<bst_gpair> &gpair,
const FMatrix &fmat, const FMatrix &fmat,
const std::vector<unsigned> &root_index, const BoosterInfo &info,
const std::vector<RegTree*> &trees) { const std::vector<RegTree*> &trees) {
// rescale learning rate according to size of trees // rescale learning rate according to size of trees
float lr = param.learning_rate; float lr = param.learning_rate;
@ -33,7 +33,7 @@ class ColMaker: public IUpdater<FMatrix> {
// build tree // build tree
for (size_t i = 0; i < trees.size(); ++i) { for (size_t i = 0; i < trees.size(); ++i) {
Builder builder(param); Builder builder(param);
builder.Update(gpair, fmat, root_index, trees[i]); builder.Update(gpair, fmat, info, trees[i]);
} }
param.learning_rate = lr; param.learning_rate = lr;
} }
@ -77,9 +77,9 @@ class ColMaker: public IUpdater<FMatrix> {
// update one tree, growing // update one tree, growing
virtual void Update(const std::vector<bst_gpair> &gpair, virtual void Update(const std::vector<bst_gpair> &gpair,
const FMatrix &fmat, const FMatrix &fmat,
const std::vector<unsigned> &root_index, const BoosterInfo &info,
RegTree *p_tree) { RegTree *p_tree) {
this->InitData(gpair, fmat, root_index, *p_tree); this->InitData(gpair, fmat, info.root_index, *p_tree);
this->InitNewNode(qexpand, gpair, fmat, *p_tree); this->InitNewNode(qexpand, gpair, fmat, *p_tree);
for (int depth = 0; depth < param.max_depth; ++depth) { for (int depth = 0; depth < param.max_depth; ++depth) {

View File

@ -24,7 +24,7 @@ class TreePruner: public IUpdater<FMatrix> {
// update the tree, do pruning // update the tree, do pruning
virtual void Update(const std::vector<bst_gpair> &gpair, virtual void Update(const std::vector<bst_gpair> &gpair,
const FMatrix &fmat, const FMatrix &fmat,
const std::vector<unsigned> &root_index, const BoosterInfo &info,
const std::vector<RegTree*> &trees) { const std::vector<RegTree*> &trees) {
// rescale learning rate according to size of trees // rescale learning rate according to size of trees
float lr = param.learning_rate; float lr = param.learning_rate;

View File

@ -24,7 +24,7 @@ class TreeRefresher: public IUpdater<FMatrix> {
// update the tree, do pruning // update the tree, do pruning
virtual void Update(const std::vector<bst_gpair> &gpair, virtual void Update(const std::vector<bst_gpair> &gpair,
const FMatrix &fmat, const FMatrix &fmat,
const std::vector<unsigned> &root_index, const BoosterInfo &info,
const std::vector<RegTree*> &trees) { const std::vector<RegTree*> &trees) {
if (trees.size() == 0) return; if (trees.size() == 0) return;
// number of threads // number of threads
@ -66,7 +66,7 @@ class TreeRefresher: public IUpdater<FMatrix> {
feats.Fill(inst); feats.Fill(inst);
for (size_t j = 0; j < trees.size(); ++j) { for (size_t j = 0; j < trees.size(); ++j) {
AddStats(*trees[j], feats, gpair[ridx], AddStats(*trees[j], feats, gpair[ridx],
root_index.size() == 0 ? 0 : root_index[ridx], info.GetRoot(j),
&stemp[tid * trees.size() + j]); &stemp[tid * trees.size() + j]);
} }
feats.Drop(inst); feats.Drop(inst);