add rabit checkpoint to xgb

This commit is contained in:
tqchen 2014-12-20 01:05:40 -08:00
parent 8e16cc4617
commit deb21351b9
5 changed files with 83 additions and 29 deletions

View File

@ -32,10 +32,10 @@ class GBLinear : public IGradBooster {
model.param.SetParam(name, val);
}
}
virtual void LoadModel(utils::IStream &fi) {
virtual void LoadModel(utils::IStream &fi, bool with_pbuffer) {
model.LoadModel(fi);
}
virtual void SaveModel(utils::IStream &fo) const {
virtual void SaveModel(utils::IStream &fo, bool with_pbuffer) const {
model.SaveModel(fo);
}
virtual void InitModel(void) {

View File

@ -27,13 +27,15 @@ class IGradBooster {
/*!
* \brief load model from stream
* \param fi input stream
* \param with_pbuffer whether the incoming data contains pbuffer
*/
virtual void LoadModel(utils::IStream &fi) = 0;
virtual void LoadModel(utils::IStream &fi, bool with_pbuffer) = 0;
/*!
* \brief save model to stream
* \param fo output stream
* \param with_pbuffer whether save out pbuffer
*/
virtual void SaveModel(utils::IStream &fo) const = 0;
virtual void SaveModel(utils::IStream &fo, bool with_pbuffer) const = 0;
/*!
* \brief initialize the model
*/

View File

@ -39,7 +39,7 @@ class GBTree : public IGradBooster {
tparam.SetParam(name, val);
if (trees.size() == 0) mparam.SetParam(name, val);
}
virtual void LoadModel(utils::IStream &fi) {
virtual void LoadModel(utils::IStream &fi, bool with_pbuffer) {
this->Clear();
utils::Check(fi.Read(&mparam, sizeof(ModelParam)) != 0,
"GBTree: invalid model file");
@ -56,13 +56,19 @@ class GBTree : public IGradBooster {
if (mparam.num_pbuffer != 0) {
pred_buffer.resize(mparam.PredBufferSize());
pred_counter.resize(mparam.PredBufferSize());
utils::Check(fi.Read(&pred_buffer[0], pred_buffer.size() * sizeof(float)) != 0,
"GBTree: invalid model file");
utils::Check(fi.Read(&pred_counter[0], pred_counter.size() * sizeof(unsigned)) != 0,
"GBTree: invalid model file");
if (with_pbuffer) {
utils::Check(fi.Read(&pred_buffer[0], pred_buffer.size() * sizeof(float)) != 0,
"GBTree: invalid model file");
utils::Check(fi.Read(&pred_counter[0], pred_counter.size() * sizeof(unsigned)) != 0,
"GBTree: invalid model file");
} else {
// reset the predict buffer if the input does not contain one
std::fill(pred_buffer.begin(), pred_buffer.end(), 0.0f);
std::fill(pred_counter.begin(), pred_counter.end(), 0);
}
}
}
virtual void SaveModel(utils::IStream &fo) const {
virtual void SaveModel(utils::IStream &fo, bool with_pbuffer) const {
utils::Assert(mparam.num_trees == static_cast<int>(trees.size()), "GBTree");
fo.Write(&mparam, sizeof(ModelParam));
for (size_t i = 0; i < trees.size(); ++i) {
@ -71,7 +77,7 @@ class GBTree : public IGradBooster {
if (tree_info.size() != 0) {
fo.Write(&tree_info[0], sizeof(int) * tree_info.size());
}
if (mparam.num_pbuffer != 0) {
if (mparam.num_pbuffer != 0 && with_pbuffer) {
fo.Write(&pred_buffer[0], pred_buffer.size() * sizeof(float));
fo.Write(&pred_counter[0], pred_counter.size() * sizeof(unsigned));
}

View File

@ -23,7 +23,7 @@ namespace learner {
 * \brief learner that performs gradient boosting on specific objective functions
 * and does training and prediction
 */
class BoostLearner {
class BoostLearner : public rabit::ISerializable {
public:
BoostLearner(void) {
obj_ = NULL;
@ -35,7 +35,7 @@ class BoostLearner {
distributed_mode = 0;
pred_buffer_size = 0;
}
~BoostLearner(void) {
virtual ~BoostLearner(void) {
if (obj_ != NULL) delete obj_;
if (gbm_ != NULL) delete gbm_;
}
@ -140,9 +140,9 @@ class BoostLearner {
/*!
* \brief load model from stream
* \param fi input stream
* \param keep_predbuffer whether to keep predict buffer
* \param with_pbuffer whether to load with predict buffer
*/
inline void LoadModel(utils::IStream &fi, bool keep_predbuffer = true) {
inline void LoadModel(utils::IStream &fi, bool with_pbuffer = true) {
utils::Check(fi.Read(&mparam, sizeof(ModelParam)) != 0,
"BoostLearner: wrong model format");
utils::Check(fi.Read(&name_obj_), "BoostLearner: wrong model format");
@ -151,11 +151,23 @@ class BoostLearner {
if (obj_ != NULL) delete obj_;
if (gbm_ != NULL) delete gbm_;
this->InitObjGBM();
gbm_->LoadModel(fi);
if (keep_predbuffer && distributed_mode == 2 && rabit::GetRank() != 0) {
gbm_->LoadModel(fi, with_pbuffer);
if (with_pbuffer && distributed_mode == 2 && rabit::GetRank() != 0) {
gbm_->ResetPredBuffer(pred_buffer_size);
}
}
// rabit load model from rabit checkpoint
virtual void Load(rabit::IStream &fi) {
RabitStreamAdapter fs(fi);
// for row split, we should not keep pbuffer
this->LoadModel(fs, distributed_mode != 2);
}
// rabit save model to rabit checkpoint
virtual void Save(rabit::IStream &fo) const {
RabitStreamAdapter fs(fo);
// for row split, we should not keep pbuffer
this->SaveModel(fs, distributed_mode != 2);
}
/*!
* \brief load model from file
* \param fname file name
@ -165,11 +177,11 @@ class BoostLearner {
this->LoadModel(fi);
fi.Close();
}
inline void SaveModel(utils::IStream &fo) const {
inline void SaveModel(utils::IStream &fo, bool with_pbuffer = true) const {
fo.Write(&mparam, sizeof(ModelParam));
fo.Write(name_obj_);
fo.Write(name_gbm_);
gbm_->SaveModel(fo);
gbm_->SaveModel(fo, with_pbuffer);
}
/*!
* \brief save model into file
@ -394,6 +406,23 @@ class BoostLearner {
// data structure field
/*! \brief the entries indicates that we have internal prediction cache */
std::vector<CacheEntry> cache_;
private:
// adapt rabit stream to utils stream
struct RabitStreamAdapter : public utils::IStream {
// rabit stream
rabit::IStream &fs;
// constructor
RabitStreamAdapter(rabit::IStream &fs) : fs(fs) {}
// destructor
virtual ~RabitStreamAdapter(void){}
virtual size_t Read(void *ptr, size_t size) {
return fs.Read(ptr, size);
}
virtual void Write(const void *ptr, size_t size) {
fs.Write(ptr, size);
}
};
};
} // namespace learner
} // namespace xgboost

View File

@ -31,14 +31,32 @@ class BoostLearnTask {
this->SetParam(name, val);
}
}
// whether need data rank
bool need_data_rank = strchr(train_path.c_str(), '%') != NULL;
// if need data rank in loading, initialize rabit engine before load data
// otherwise, initialize rabit engine after loading data
// lazy initialization of rabit engine can be helpful in speculative execution
if (need_data_rank) rabit::Init(argc, argv);
this->InitData();
if (!need_data_rank) rabit::Init(argc, argv);
if (rabit::IsDistributed()) {
std::string pname = rabit::GetProcessorName();
printf("start %s:%d\n", pname.c_str(), rabit::GetRank());
}
if (rabit::IsDistributed()) {
this->SetParam("data_split", "col");
}
if (rabit::GetRank() != 0) {
this->SetParam("silent", "2");
}
this->InitData();
this->InitLearner();
if (task == "train") {
// if task is training, will try recover from checkpoint
this->TaskTrain();
return 0;
} else {
this->InitLearner();
}
if (task == "dump") {
this->TaskDump(); return 0;
}
@ -47,8 +65,6 @@ class BoostLearnTask {
}
if (task == "pred") {
this->TaskPred();
} else {
this->TaskTrain();
}
return 0;
}
@ -152,10 +168,13 @@ class BoostLearnTask {
}
}
inline void TaskTrain(void) {
int version = rabit::LoadCheckPoint(&learner);
if (version == 0) this->InitLearner();
const time_t start = time(NULL);
unsigned long elapsed = 0;
learner.CheckInit(data);
for (int i = 0; i < num_round; ++i) {
for (int i = version; i < num_round; ++i) {
elapsed = (unsigned long)(time(NULL) - start);
if (!silent) printf("boosting round %d, %lu sec elapsed\n", i, elapsed);
learner.UpdateOneIter(i, *data);
@ -166,6 +185,9 @@ class BoostLearnTask {
if (save_period != 0 && (i + 1) % save_period == 0) {
this->SaveModel(i);
}
utils::Assert(rabit::VersionNumber() == i, "incorrect version number");
// checkpoint the model
rabit::CheckPoint(&learner);
elapsed = (unsigned long)(time(NULL) - start);
}
// always save final round
@ -263,11 +285,6 @@ class BoostLearnTask {
}
int main(int argc, char *argv[]){
rabit::Init(argc, argv);
if (rabit::IsDistributed()) {
std::string pname = rabit::GetProcessorName();
printf("start %s:%d\n", pname.c_str(), rabit::GetRank());
}
xgboost::random::Seed(0);
xgboost::BoostLearnTask tsk;
int ret = tsk.Run(argc, argv);