add rabit checkpoint to xgb

This commit is contained in:
tqchen
2014-12-20 01:05:40 -08:00
parent 8e16cc4617
commit deb21351b9
5 changed files with 83 additions and 29 deletions

View File

@@ -23,7 +23,7 @@ namespace learner {
* \brief learner that takes do gradient boosting on specific objective functions
* and do training and prediction
*/
class BoostLearner {
class BoostLearner : public rabit::ISerializable {
public:
BoostLearner(void) {
obj_ = NULL;
@@ -35,7 +35,7 @@ class BoostLearner {
distributed_mode = 0;
pred_buffer_size = 0;
}
~BoostLearner(void) {
virtual ~BoostLearner(void) {
if (obj_ != NULL) delete obj_;
if (gbm_ != NULL) delete gbm_;
}
@@ -140,9 +140,9 @@ class BoostLearner {
/*!
* \brief load model from stream
* \param fi input stream
* \param keep_predbuffer whether to keep predict buffer
* \param with_pbuffer whether to load with predict buffer
*/
inline void LoadModel(utils::IStream &fi, bool keep_predbuffer = true) {
inline void LoadModel(utils::IStream &fi, bool with_pbuffer = true) {
utils::Check(fi.Read(&mparam, sizeof(ModelParam)) != 0,
"BoostLearner: wrong model format");
utils::Check(fi.Read(&name_obj_), "BoostLearner: wrong model format");
@@ -151,11 +151,23 @@ class BoostLearner {
if (obj_ != NULL) delete obj_;
if (gbm_ != NULL) delete gbm_;
this->InitObjGBM();
gbm_->LoadModel(fi);
if (keep_predbuffer && distributed_mode == 2 && rabit::GetRank() != 0) {
gbm_->LoadModel(fi, with_pbuffer);
if (with_pbuffer && distributed_mode == 2 && rabit::GetRank() != 0) {
gbm_->ResetPredBuffer(pred_buffer_size);
}
}
// rabit load model from rabit checkpoint
virtual void Load(rabit::IStream &fi) {
RabitStreamAdapter fs(fi);
// for row split, we should not keep pbuffer
this->LoadModel(fs, distributed_mode != 2);
}
// rabit save model to rabit checkpoint
virtual void Save(rabit::IStream &fo) const {
RabitStreamAdapter fs(fo);
// for row split, we should not keep pbuffer
this->SaveModel(fs, distributed_mode != 2);
}
/*!
* \brief load model from file
* \param fname file name
@@ -165,11 +177,11 @@ class BoostLearner {
this->LoadModel(fi);
fi.Close();
}
inline void SaveModel(utils::IStream &fo) const {
inline void SaveModel(utils::IStream &fo, bool with_pbuffer = true) const {
fo.Write(&mparam, sizeof(ModelParam));
fo.Write(name_obj_);
fo.Write(name_gbm_);
gbm_->SaveModel(fo);
gbm_->SaveModel(fo, with_pbuffer);
}
/*!
* \brief save model into file
@@ -394,6 +406,23 @@ class BoostLearner {
// data structure field
/*! \brief the entries indicates that we have internal prediction cache */
std::vector<CacheEntry> cache_;
private:
// adapt rabit stream to utils stream
struct RabitStreamAdapter : public utils::IStream {
// rabit stream
rabit::IStream &fs;
// constructr
RabitStreamAdapter(rabit::IStream &fs) : fs(fs) {}
// destructor
virtual ~RabitStreamAdapter(void){}
virtual size_t Read(void *ptr, size_t size) {
return fs.Read(ptr, size);
}
virtual void Write(const void *ptr, size_t size) {
fs.Write(ptr, size);
}
};
};
} // namespace learner
} // namespace xgboost