From deb21351b976aec90454aa2cef8c0d62dca5edf8 Mon Sep 17 00:00:00 2001 From: tqchen Date: Sat, 20 Dec 2014 01:05:40 -0800 Subject: [PATCH] add rabit checkpoint to xgb --- src/gbm/gblinear-inl.hpp | 4 ++-- src/gbm/gbm.h | 6 +++-- src/gbm/gbtree-inl.hpp | 20 +++++++++++------ src/learner/learner-inl.hpp | 45 ++++++++++++++++++++++++++++++------- src/xgboost_main.cpp | 37 +++++++++++++++++++++--------- 5 files changed, 83 insertions(+), 29 deletions(-) diff --git a/src/gbm/gblinear-inl.hpp b/src/gbm/gblinear-inl.hpp index 6d507ac6e..005eada55 100644 --- a/src/gbm/gblinear-inl.hpp +++ b/src/gbm/gblinear-inl.hpp @@ -32,10 +32,10 @@ class GBLinear : public IGradBooster { model.param.SetParam(name, val); } } - virtual void LoadModel(utils::IStream &fi) { + virtual void LoadModel(utils::IStream &fi, bool with_pbuffer) { model.LoadModel(fi); } - virtual void SaveModel(utils::IStream &fo) const { + virtual void SaveModel(utils::IStream &fo, bool with_pbuffer) const { model.SaveModel(fo); } virtual void InitModel(void) { diff --git a/src/gbm/gbm.h b/src/gbm/gbm.h index f8eae6dbb..8799a7af0 100644 --- a/src/gbm/gbm.h +++ b/src/gbm/gbm.h @@ -27,13 +27,15 @@ class IGradBooster { /*! * \brief load model from stream * \param fi input stream + * \param with_pbuffer whether the incoming data contains pbuffer */ - virtual void LoadModel(utils::IStream &fi) = 0; + virtual void LoadModel(utils::IStream &fi, bool with_pbuffer) = 0; /*! * \brief save model to stream * \param fo output stream + * \param with_pbuffer whether save out pbuffer */ - virtual void SaveModel(utils::IStream &fo) const = 0; + virtual void SaveModel(utils::IStream &fo, bool with_pbuffer) const = 0; /*! 
* \brief initialize the model */ diff --git a/src/gbm/gbtree-inl.hpp b/src/gbm/gbtree-inl.hpp index 8d511f06e..e63ea42fa 100644 --- a/src/gbm/gbtree-inl.hpp +++ b/src/gbm/gbtree-inl.hpp @@ -39,7 +39,7 @@ class GBTree : public IGradBooster { tparam.SetParam(name, val); if (trees.size() == 0) mparam.SetParam(name, val); } - virtual void LoadModel(utils::IStream &fi) { + virtual void LoadModel(utils::IStream &fi, bool with_pbuffer) { this->Clear(); utils::Check(fi.Read(&mparam, sizeof(ModelParam)) != 0, "GBTree: invalid model file"); @@ -56,13 +56,19 @@ class GBTree : public IGradBooster { if (mparam.num_pbuffer != 0) { pred_buffer.resize(mparam.PredBufferSize()); pred_counter.resize(mparam.PredBufferSize()); - utils::Check(fi.Read(&pred_buffer[0], pred_buffer.size() * sizeof(float)) != 0, - "GBTree: invalid model file"); - utils::Check(fi.Read(&pred_counter[0], pred_counter.size() * sizeof(unsigned)) != 0, - "GBTree: invalid model file"); + if (with_pbuffer) { + utils::Check(fi.Read(&pred_buffer[0], pred_buffer.size() * sizeof(float)) != 0, + "GBTree: invalid model file"); + utils::Check(fi.Read(&pred_counter[0], pred_counter.size() * sizeof(unsigned)) != 0, + "GBTree: invalid model file"); + } else { + // reset predict buffer if the input does not have them + std::fill(pred_buffer.begin(), pred_buffer.end(), 0.0f); + std::fill(pred_counter.begin(), pred_counter.end(), 0); + } } } - virtual void SaveModel(utils::IStream &fo) const { + virtual void SaveModel(utils::IStream &fo, bool with_pbuffer) const { utils::Assert(mparam.num_trees == static_cast(trees.size()), "GBTree"); fo.Write(&mparam, sizeof(ModelParam)); for (size_t i = 0; i < trees.size(); ++i) { @@ -71,7 +77,7 @@ class GBTree : public IGradBooster { if (tree_info.size() != 0) { fo.Write(&tree_info[0], sizeof(int) * tree_info.size()); } - if (mparam.num_pbuffer != 0) { + if (mparam.num_pbuffer != 0 && with_pbuffer) { fo.Write(&pred_buffer[0], pred_buffer.size() * sizeof(float)); fo.Write(&pred_counter[0], 
pred_counter.size() * sizeof(unsigned)); } diff --git a/src/learner/learner-inl.hpp b/src/learner/learner-inl.hpp index 6ca3b7c7a..1640071b6 100644 --- a/src/learner/learner-inl.hpp +++ b/src/learner/learner-inl.hpp @@ -23,7 +23,7 @@ namespace learner { * \brief learner that takes do gradient boosting on specific objective functions * and do training and prediction */ -class BoostLearner { +class BoostLearner : public rabit::ISerializable { public: BoostLearner(void) { obj_ = NULL; @@ -35,7 +35,7 @@ class BoostLearner { distributed_mode = 0; pred_buffer_size = 0; } - ~BoostLearner(void) { + virtual ~BoostLearner(void) { if (obj_ != NULL) delete obj_; if (gbm_ != NULL) delete gbm_; } @@ -140,9 +140,9 @@ class BoostLearner { /*! * \brief load model from stream * \param fi input stream - * \param keep_predbuffer whether to keep predict buffer + * \param with_pbuffer whether to load with predict buffer */ - inline void LoadModel(utils::IStream &fi, bool keep_predbuffer = true) { + inline void LoadModel(utils::IStream &fi, bool with_pbuffer = true) { utils::Check(fi.Read(&mparam, sizeof(ModelParam)) != 0, "BoostLearner: wrong model format"); utils::Check(fi.Read(&name_obj_), "BoostLearner: wrong model format"); @@ -151,11 +151,23 @@ class BoostLearner { if (obj_ != NULL) delete obj_; if (gbm_ != NULL) delete gbm_; this->InitObjGBM(); - gbm_->LoadModel(fi); - if (keep_predbuffer && distributed_mode == 2 && rabit::GetRank() != 0) { + gbm_->LoadModel(fi, with_pbuffer); + if (with_pbuffer && distributed_mode == 2 && rabit::GetRank() != 0) { gbm_->ResetPredBuffer(pred_buffer_size); } } + // rabit load model from rabit checkpoint + virtual void Load(rabit::IStream &fi) { + RabitStreamAdapter fs(fi); + // for row split, we should not keep pbuffer + this->LoadModel(fs, distributed_mode != 2); + } + // rabit save model to rabit checkpoint + virtual void Save(rabit::IStream &fo) const { + RabitStreamAdapter fs(fo); + // for row split, we should not keep pbuffer + 
this->SaveModel(fs, distributed_mode != 2); + } /*! * \brief load model from file * \param fname file name @@ -165,11 +177,11 @@ class BoostLearner { this->LoadModel(fi); fi.Close(); } - inline void SaveModel(utils::IStream &fo) const { + inline void SaveModel(utils::IStream &fo, bool with_pbuffer = true) const { fo.Write(&mparam, sizeof(ModelParam)); fo.Write(name_obj_); fo.Write(name_gbm_); - gbm_->SaveModel(fo); + gbm_->SaveModel(fo, with_pbuffer); } /*! * \brief save model into file @@ -394,6 +406,23 @@ class BoostLearner { // data structure field /*! \brief the entries indicates that we have internal prediction cache */ std::vector cache_; + + private: + // adapt rabit stream to utils stream + struct RabitStreamAdapter : public utils::IStream { + // rabit stream + rabit::IStream &fs; + // constructor + RabitStreamAdapter(rabit::IStream &fs) : fs(fs) {} + // destructor + virtual ~RabitStreamAdapter(void){} + virtual size_t Read(void *ptr, size_t size) { + return fs.Read(ptr, size); + } + virtual void Write(const void *ptr, size_t size) { + fs.Write(ptr, size); + } + }; }; } // namespace learner } // namespace xgboost diff --git a/src/xgboost_main.cpp b/src/xgboost_main.cpp index 9583a2278..d25140461 100644 --- a/src/xgboost_main.cpp +++ b/src/xgboost_main.cpp @@ -31,14 +31,32 @@ class BoostLearnTask { this->SetParam(name, val); } } + // whether need data rank + bool need_data_rank = strchr(train_path.c_str(), '%') != NULL; + // if need data rank in loading, initialize rabit engine before load data + // otherwise, initialize rabit engine after loading data + // lazy initialization of rabit engine can be helpful in speculative execution + if (need_data_rank) rabit::Init(argc, argv); + this->InitData(); + if (!need_data_rank) rabit::Init(argc, argv); + if (rabit::IsDistributed()) { + std::string pname = rabit::GetProcessorName(); + printf("start %s:%d\n", pname.c_str(), rabit::GetRank()); + } if (rabit::IsDistributed()) { this->SetParam("data_split", "col"); } if 
(rabit::GetRank() != 0) { this->SetParam("silent", "2"); } - this->InitData(); - this->InitLearner(); + + if (task == "train") { + // if task is training, will try recover from checkpoint + this->TaskTrain(); + return 0; + } else { + this->InitLearner(); + } if (task == "dump") { this->TaskDump(); return 0; } @@ -47,8 +65,6 @@ class BoostLearnTask { } if (task == "pred") { this->TaskPred(); - } else { - this->TaskTrain(); } return 0; } @@ -152,10 +168,13 @@ class BoostLearnTask { } } inline void TaskTrain(void) { + int version = rabit::LoadCheckPoint(&learner); + if (version == 0) this->InitLearner(); + const time_t start = time(NULL); unsigned long elapsed = 0; learner.CheckInit(data); - for (int i = 0; i < num_round; ++i) { + for (int i = version; i < num_round; ++i) { elapsed = (unsigned long)(time(NULL) - start); if (!silent) printf("boosting round %d, %lu sec elapsed\n", i, elapsed); learner.UpdateOneIter(i, *data); @@ -166,6 +185,9 @@ class BoostLearnTask { if (save_period != 0 && (i + 1) % save_period == 0) { this->SaveModel(i); } + utils::Assert(rabit::VersionNumber() == i, "incorrect version number"); + // checkpoint the model + rabit::CheckPoint(&learner); elapsed = (unsigned long)(time(NULL) - start); } // always save final round @@ -263,11 +285,6 @@ class BoostLearnTask { } int main(int argc, char *argv[]){ - rabit::Init(argc, argv); - if (rabit::IsDistributed()) { - std::string pname = rabit::GetProcessorName(); - printf("start %s:%d\n", pname.c_str(), rabit::GetRank()); - } xgboost::random::Seed(0); xgboost::BoostLearnTask tsk; int ret = tsk.Run(argc, argv);