add rabit checkpoint to xgb

tqchen 2014-12-20 01:05:40 -08:00
parent 8e16cc4617
commit deb21351b9
5 changed files with 83 additions and 29 deletions


@@ -32,10 +32,10 @@ class GBLinear : public IGradBooster {
       model.param.SetParam(name, val);
     }
   }
-  virtual void LoadModel(utils::IStream &fi) {
+  virtual void LoadModel(utils::IStream &fi, bool with_pbuffer) {
     model.LoadModel(fi);
   }
-  virtual void SaveModel(utils::IStream &fo) const {
+  virtual void SaveModel(utils::IStream &fo, bool with_pbuffer) const {
     model.SaveModel(fo);
   }
   virtual void InitModel(void) {


@ -27,13 +27,15 @@ class IGradBooster {
/*! /*!
* \brief load model from stream * \brief load model from stream
* \param fi input stream * \param fi input stream
* \param with_pbuffer whether the incoming data contains pbuffer
*/ */
virtual void LoadModel(utils::IStream &fi) = 0; virtual void LoadModel(utils::IStream &fi, bool with_pbuffer) = 0;
/*! /*!
* \brief save model to stream * \brief save model to stream
* \param fo output stream * \param fo output stream
* \param with_pbuffer whether save out pbuffer
*/ */
virtual void SaveModel(utils::IStream &fo) const = 0; virtual void SaveModel(utils::IStream &fo, bool with_pbuffer) const = 0;
/*! /*!
* \brief initialize the model * \brief initialize the model
*/ */


@@ -39,7 +39,7 @@ class GBTree : public IGradBooster {
     tparam.SetParam(name, val);
     if (trees.size() == 0) mparam.SetParam(name, val);
   }
-  virtual void LoadModel(utils::IStream &fi) {
+  virtual void LoadModel(utils::IStream &fi, bool with_pbuffer) {
     this->Clear();
     utils::Check(fi.Read(&mparam, sizeof(ModelParam)) != 0,
                  "GBTree: invalid model file");
@@ -56,13 +56,19 @@ class GBTree : public IGradBooster {
     if (mparam.num_pbuffer != 0) {
       pred_buffer.resize(mparam.PredBufferSize());
       pred_counter.resize(mparam.PredBufferSize());
-      utils::Check(fi.Read(&pred_buffer[0], pred_buffer.size() * sizeof(float)) != 0,
-                   "GBTree: invalid model file");
-      utils::Check(fi.Read(&pred_counter[0], pred_counter.size() * sizeof(unsigned)) != 0,
-                   "GBTree: invalid model file");
+      if (with_pbuffer) {
+        utils::Check(fi.Read(&pred_buffer[0], pred_buffer.size() * sizeof(float)) != 0,
+                     "GBTree: invalid model file");
+        utils::Check(fi.Read(&pred_counter[0], pred_counter.size() * sizeof(unsigned)) != 0,
+                     "GBTree: invalid model file");
+      } else {
+        // reset predict buffer if the input do not have them
+        std::fill(pred_buffer.begin(), pred_buffer.end(), 0.0f);
+        std::fill(pred_counter.begin(), pred_counter.end(), 0);
+      }
     }
   }
-  virtual void SaveModel(utils::IStream &fo) const {
+  virtual void SaveModel(utils::IStream &fo, bool with_pbuffer) const {
     utils::Assert(mparam.num_trees == static_cast<int>(trees.size()), "GBTree");
     fo.Write(&mparam, sizeof(ModelParam));
     for (size_t i = 0; i < trees.size(); ++i) {
@@ -71,7 +77,7 @@ class GBTree : public IGradBooster {
     if (tree_info.size() != 0) {
       fo.Write(&tree_info[0], sizeof(int) * tree_info.size());
     }
-    if (mparam.num_pbuffer != 0) {
+    if (mparam.num_pbuffer != 0 && with_pbuffer) {
       fo.Write(&pred_buffer[0], pred_buffer.size() * sizeof(float));
       fo.Write(&pred_counter[0], pred_counter.size() * sizeof(unsigned));
     }
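
The with_pbuffer flag above is not written into the stream itself, so the writer and the reader must agree on it, and a loader that is told the buffer is absent zeroes it rather than reading past the end. A minimal self-contained sketch of that convention; MemStream and Model here are illustrative stand-ins, not xgboost classes:

#include <algorithm>
#include <cstdio>
#include <cstring>
#include <vector>

// toy in-memory stream with the same Read/Write surface as the code above
struct MemStream {
  std::vector<char> data;
  size_t pos;
  MemStream() : pos(0) {}
  void Write(const void *ptr, size_t size) {
    const char *p = static_cast<const char*>(ptr);
    data.insert(data.end(), p, p + size);
  }
  size_t Read(void *ptr, size_t size) {
    size_t n = std::min(size, data.size() - pos);
    std::memcpy(ptr, &data[pos], n);
    pos += n;
    return n;
  }
};

// toy model: fixed-size weights plus an optional prediction cache
struct Model {
  std::vector<float> weights;
  std::vector<float> pred_buffer;
  Model() : weights(2, 0.5f), pred_buffer(2, 3.0f) {}
  void Save(MemStream &fo, bool with_pbuffer) const {
    fo.Write(&weights[0], weights.size() * sizeof(float));
    if (with_pbuffer) {
      fo.Write(&pred_buffer[0], pred_buffer.size() * sizeof(float));
    }
  }
  void Load(MemStream &fi, bool with_pbuffer) {
    fi.Read(&weights[0], weights.size() * sizeof(float));
    if (with_pbuffer) {
      fi.Read(&pred_buffer[0], pred_buffer.size() * sizeof(float));
    } else {
      // same reset GBTree does: the cache is recomputed later instead of loaded
      std::fill(pred_buffer.begin(), pred_buffer.end(), 0.0f);
    }
  }
};

int main() {
  Model a, b;
  MemStream checkpoint;
  a.Save(checkpoint, /*with_pbuffer=*/false);  // compact form, e.g. for a checkpoint
  b.Load(checkpoint, /*with_pbuffer=*/false);  // flags must match: the stream does not record them
  std::printf("pred_buffer[0] after compact load: %g\n", b.pred_buffer[0]);  // prints 0
  return 0;
}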


@@ -23,7 +23,7 @@ namespace learner {
  * \brief learner that takes do gradient boosting on specific objective functions
  * and do training and prediction
  */
-class BoostLearner {
+class BoostLearner : public rabit::ISerializable {
  public:
   BoostLearner(void) {
     obj_ = NULL;
@@ -35,7 +35,7 @@ class BoostLearner {
     distributed_mode = 0;
     pred_buffer_size = 0;
   }
-  ~BoostLearner(void) {
+  virtual ~BoostLearner(void) {
     if (obj_ != NULL) delete obj_;
     if (gbm_ != NULL) delete gbm_;
   }
@@ -140,9 +140,9 @@ class BoostLearner {
   /*!
    * \brief load model from stream
    * \param fi input stream
-   * \param keep_predbuffer whether to keep predict buffer
+   * \param with_pbuffer whether to load with predict buffer
    */
-  inline void LoadModel(utils::IStream &fi, bool keep_predbuffer = true) {
+  inline void LoadModel(utils::IStream &fi, bool with_pbuffer = true) {
     utils::Check(fi.Read(&mparam, sizeof(ModelParam)) != 0,
                  "BoostLearner: wrong model format");
     utils::Check(fi.Read(&name_obj_), "BoostLearner: wrong model format");
@@ -151,11 +151,23 @@ class BoostLearner {
     if (obj_ != NULL) delete obj_;
     if (gbm_ != NULL) delete gbm_;
     this->InitObjGBM();
-    gbm_->LoadModel(fi);
-    if (keep_predbuffer && distributed_mode == 2 && rabit::GetRank() != 0) {
+    gbm_->LoadModel(fi, with_pbuffer);
+    if (with_pbuffer && distributed_mode == 2 && rabit::GetRank() != 0) {
       gbm_->ResetPredBuffer(pred_buffer_size);
     }
   }
+  // rabit load model from rabit checkpoint
+  virtual void Load(rabit::IStream &fi) {
+    RabitStreamAdapter fs(fi);
+    // for row split, we should not keep pbuffer
+    this->LoadModel(fs, distributed_mode != 2);
+  }
+  // rabit save model to rabit checkpoint
+  virtual void Save(rabit::IStream &fo) const {
+    RabitStreamAdapter fs(fo);
+    // for row split, we should not keep pbuffer
+    this->SaveModel(fs, distributed_mode != 2);
+  }
   /*!
    * \brief load model from file
    * \param fname file name
@@ -165,11 +177,11 @@ class BoostLearner {
     this->LoadModel(fi);
     fi.Close();
   }
-  inline void SaveModel(utils::IStream &fo) const {
+  inline void SaveModel(utils::IStream &fo, bool with_pbuffer = true) const {
     fo.Write(&mparam, sizeof(ModelParam));
     fo.Write(name_obj_);
     fo.Write(name_gbm_);
-    gbm_->SaveModel(fo);
+    gbm_->SaveModel(fo, with_pbuffer);
   }
   /*!
    * \brief save model into file
@@ -394,6 +406,23 @@ class BoostLearner {
   // data structure field
   /*! \brief the entries indicates that we have internal prediction cache */
   std::vector<CacheEntry> cache_;
+ private:
+  // adapt rabit stream to utils stream
+  struct RabitStreamAdapter : public utils::IStream {
+    // rabit stream
+    rabit::IStream &fs;
+    // constructr
+    RabitStreamAdapter(rabit::IStream &fs) : fs(fs) {}
+    // destructor
+    virtual ~RabitStreamAdapter(void) {}
+    virtual size_t Read(void *ptr, size_t size) {
+      return fs.Read(ptr, size);
+    }
+    virtual void Write(const void *ptr, size_t size) {
+      fs.Write(ptr, size);
+    }
+  };
 };
 } // namespace learner
 } // namespace xgboost
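
The Load/Save overrides above are what the rabit checkpoint layer invokes, and RabitStreamAdapter exists only to re-expose a rabit::IStream through the utils::IStream interface that LoadModel/SaveModel already expect. A self-contained sketch of that forwarding-adapter pattern; SourceStream and TargetStream stand in for the two real stream interfaces, and only the Read/Write surface visible in the diff is modeled:

#include <algorithm>
#include <cstdio>
#include <cstring>
#include <string>

// stand-in for rabit::IStream: the interface the checkpoint layer hands us
struct SourceStream {
  virtual size_t Read(void *ptr, size_t size) = 0;
  virtual void Write(const void *ptr, size_t size) = 0;
  virtual ~SourceStream() {}
};
// stand-in for utils::IStream: the interface the existing model code expects
struct TargetStream {
  virtual size_t Read(void *ptr, size_t size) = 0;
  virtual void Write(const void *ptr, size_t size) = 0;
  virtual ~TargetStream() {}
};
// the adapter holds a reference and forwards every call, just like
// RabitStreamAdapter in the hunk above
struct StreamAdapter : public TargetStream {
  explicit StreamAdapter(SourceStream &fs) : fs(fs) {}
  virtual size_t Read(void *ptr, size_t size) { return fs.Read(ptr, size); }
  virtual void Write(const void *ptr, size_t size) { fs.Write(ptr, size); }
 private:
  SourceStream &fs;
};
// toy SourceStream backed by a std::string, just to exercise the adapter
struct StringStream : public SourceStream {
  std::string buf;
  size_t pos;
  StringStream() : pos(0) {}
  virtual size_t Read(void *ptr, size_t size) {
    size_t n = std::min(size, buf.size() - pos);
    std::memcpy(ptr, buf.data() + pos, n);
    pos += n;
    return n;
  }
  virtual void Write(const void *ptr, size_t size) {
    buf.append(static_cast<const char*>(ptr), size);
  }
};

int main() {
  StringStream ss;        // what the checkpoint layer would give us
  StreamAdapter fs(ss);   // what model save/load code can consume
  int round = 7;
  fs.Write(&round, sizeof(round));        // stands in for SaveModel(fs, ...)
  int restored = 0;
  fs.Read(&restored, sizeof(restored));   // stands in for LoadModel(fs, ...)
  std::printf("restored = %d\n", restored);  // prints 7
  return 0;
}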


@@ -31,14 +31,32 @@ class BoostLearnTask {
         this->SetParam(name, val);
       }
     }
+    // whether need data rank
+    bool need_data_rank = strchr(train_path.c_str(), '%') != NULL;
+    // if need data rank in loading, initialize rabit engine before load data
+    // otherwise, initialize rabit engine after loading data
+    // lazy initialization of rabit engine can be helpful in speculative execution
+    if (need_data_rank) rabit::Init(argc, argv);
+    this->InitData();
+    if (!need_data_rank) rabit::Init(argc, argv);
+    if (rabit::IsDistributed()) {
+      std::string pname = rabit::GetProcessorName();
+      printf("start %s:%d\n", pname.c_str(), rabit::GetRank());
+    }
     if (rabit::IsDistributed()) {
       this->SetParam("data_split", "col");
     }
     if (rabit::GetRank() != 0) {
       this->SetParam("silent", "2");
     }
-    this->InitData();
-    this->InitLearner();
+    if (task == "train") {
+      // if task is training, will try recover from checkpoint
+      this->TaskTrain();
+      return 0;
+    } else {
+      this->InitLearner();
+    }
     if (task == "dump") {
       this->TaskDump(); return 0;
     }
@@ -47,8 +65,6 @@ class BoostLearnTask {
     }
     if (task == "pred") {
       this->TaskPred();
-    } else {
-      this->TaskTrain();
     }
     return 0;
   }
@@ -152,10 +168,13 @@ class BoostLearnTask {
     }
   }
   inline void TaskTrain(void) {
+    int version = rabit::LoadCheckPoint(&learner);
+    if (version == 0) this->InitLearner();
     const time_t start = time(NULL);
     unsigned long elapsed = 0;
     learner.CheckInit(data);
-    for (int i = 0; i < num_round; ++i) {
+    for (int i = version; i < num_round; ++i) {
       elapsed = (unsigned long)(time(NULL) - start);
       if (!silent) printf("boosting round %d, %lu sec elapsed\n", i, elapsed);
       learner.UpdateOneIter(i, *data);
@@ -166,6 +185,9 @@ class BoostLearnTask {
       if (save_period != 0 && (i + 1) % save_period == 0) {
         this->SaveModel(i);
       }
+      utils::Assert(rabit::VersionNumber() == i, "incorrect version number");
+      // checkpoint the model
+      rabit::CheckPoint(&learner);
       elapsed = (unsigned long)(time(NULL) - start);
     }
     // always save final round
@@ -263,11 +285,6 @@ class BoostLearnTask {
 }
 int main(int argc, char *argv[]){
-  rabit::Init(argc, argv);
-  if (rabit::IsDistributed()) {
-    std::string pname = rabit::GetProcessorName();
-    printf("start %s:%d\n", pname.c_str(), rabit::GetRank());
-  }
   xgboost::random::Seed(0);
   xgboost::BoostLearnTask tsk;
   int ret = tsk.Run(argc, argv);
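
In the TaskTrain hunks above, rabit::LoadCheckPoint returns the number of already committed rounds (0 on a fresh start, which is the only case that initializes the learner from scratch), the boosting loop resumes at that round, and rabit::CheckPoint commits the model after every round so a restarted worker repeats as little work as possible. A minimal sketch of that control flow; FakeCheckpoint is a local stand-in for rabit's LoadCheckPoint/CheckPoint/VersionNumber so the sketch runs standalone, whereas the real calls also replicate the model across workers:

#include <cassert>
#include <cstdio>
#include <vector>

// toy model and a fake checkpoint store standing in for the rabit calls
struct Model {
  std::vector<double> w;
};
struct FakeCheckpoint {
  int version;   // number of committed rounds
  Model saved;   // last committed model
  FakeCheckpoint() : version(0) {}
  int LoadCheckPoint(Model *m) {
    if (version > 0) *m = saved;   // restore only if something was committed
    return version;
  }
  void CheckPoint(const Model *m) {
    saved = *m;
    ++version;
  }
  int VersionNumber() const { return version; }
};

void TrainOneRound(Model *m, int round) {
  m->w.push_back(0.1 * round);     // stands in for learner.UpdateOneIter(round, data)
}

// same shape as TaskTrain above: resume at `version`, commit after every round
int RunTraining(FakeCheckpoint &ckpt, int num_round) {
  Model model;
  int version = ckpt.LoadCheckPoint(&model);
  if (version == 0) {
    // InitLearner() would go here: only a fresh start builds the model from scratch
  }
  int executed = 0;
  for (int i = version; i < num_round; ++i) {
    TrainOneRound(&model, i);
    assert(ckpt.VersionNumber() == i);  // mirrors the Assert in the diff
    ckpt.CheckPoint(&model);            // commit round i before moving on
    ++executed;
  }
  return executed;
}

int main() {
  FakeCheckpoint ckpt;
  // simulate a worker that finished 3 of 10 rounds before being killed ...
  RunTraining(ckpt, 3);
  // ... and a restarted worker that resumes at round 3 instead of round 0
  int resumed = RunTraining(ckpt, 10);
  std::printf("rounds run after restart: %d\n", resumed);  // prints 7
  return 0;
}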