add rabit checkpoint to xgb
This commit is contained in:
@@ -23,7 +23,7 @@ namespace learner {
|
||||
* \brief learner that takes do gradient boosting on specific objective functions
|
||||
* and do training and prediction
|
||||
*/
|
||||
class BoostLearner {
|
||||
class BoostLearner : public rabit::ISerializable {
|
||||
public:
|
||||
BoostLearner(void) {
|
||||
obj_ = NULL;
|
||||
@@ -35,7 +35,7 @@ class BoostLearner {
|
||||
distributed_mode = 0;
|
||||
pred_buffer_size = 0;
|
||||
}
|
||||
~BoostLearner(void) {
|
||||
virtual ~BoostLearner(void) {
|
||||
if (obj_ != NULL) delete obj_;
|
||||
if (gbm_ != NULL) delete gbm_;
|
||||
}
|
||||
@@ -140,9 +140,9 @@ class BoostLearner {
|
||||
/*!
|
||||
* \brief load model from stream
|
||||
* \param fi input stream
|
||||
* \param keep_predbuffer whether to keep predict buffer
|
||||
* \param with_pbuffer whether to load with predict buffer
|
||||
*/
|
||||
inline void LoadModel(utils::IStream &fi, bool keep_predbuffer = true) {
|
||||
inline void LoadModel(utils::IStream &fi, bool with_pbuffer = true) {
|
||||
utils::Check(fi.Read(&mparam, sizeof(ModelParam)) != 0,
|
||||
"BoostLearner: wrong model format");
|
||||
utils::Check(fi.Read(&name_obj_), "BoostLearner: wrong model format");
|
||||
@@ -151,11 +151,23 @@ class BoostLearner {
|
||||
if (obj_ != NULL) delete obj_;
|
||||
if (gbm_ != NULL) delete gbm_;
|
||||
this->InitObjGBM();
|
||||
gbm_->LoadModel(fi);
|
||||
if (keep_predbuffer && distributed_mode == 2 && rabit::GetRank() != 0) {
|
||||
gbm_->LoadModel(fi, with_pbuffer);
|
||||
if (with_pbuffer && distributed_mode == 2 && rabit::GetRank() != 0) {
|
||||
gbm_->ResetPredBuffer(pred_buffer_size);
|
||||
}
|
||||
}
|
||||
// rabit load model from rabit checkpoint
|
||||
virtual void Load(rabit::IStream &fi) {
|
||||
RabitStreamAdapter fs(fi);
|
||||
// for row split, we should not keep pbuffer
|
||||
this->LoadModel(fs, distributed_mode != 2);
|
||||
}
|
||||
// rabit save model to rabit checkpoint
|
||||
virtual void Save(rabit::IStream &fo) const {
|
||||
RabitStreamAdapter fs(fo);
|
||||
// for row split, we should not keep pbuffer
|
||||
this->SaveModel(fs, distributed_mode != 2);
|
||||
}
|
||||
/*!
|
||||
* \brief load model from file
|
||||
* \param fname file name
|
||||
@@ -165,11 +177,11 @@ class BoostLearner {
|
||||
this->LoadModel(fi);
|
||||
fi.Close();
|
||||
}
|
||||
inline void SaveModel(utils::IStream &fo) const {
|
||||
inline void SaveModel(utils::IStream &fo, bool with_pbuffer = true) const {
|
||||
fo.Write(&mparam, sizeof(ModelParam));
|
||||
fo.Write(name_obj_);
|
||||
fo.Write(name_gbm_);
|
||||
gbm_->SaveModel(fo);
|
||||
gbm_->SaveModel(fo, with_pbuffer);
|
||||
}
|
||||
/*!
|
||||
* \brief save model into file
|
||||
@@ -394,6 +406,23 @@ class BoostLearner {
|
||||
// data structure field
|
||||
/*! \brief the entries indicates that we have internal prediction cache */
|
||||
std::vector<CacheEntry> cache_;
|
||||
|
||||
private:
|
||||
// adapt rabit stream to utils stream
|
||||
struct RabitStreamAdapter : public utils::IStream {
|
||||
// rabit stream
|
||||
rabit::IStream &fs;
|
||||
// constructr
|
||||
RabitStreamAdapter(rabit::IStream &fs) : fs(fs) {}
|
||||
// destructor
|
||||
virtual ~RabitStreamAdapter(void){}
|
||||
virtual size_t Read(void *ptr, size_t size) {
|
||||
return fs.Read(ptr, size);
|
||||
}
|
||||
virtual void Write(const void *ptr, size_t size) {
|
||||
fs.Write(ptr, size);
|
||||
}
|
||||
};
|
||||
};
|
||||
} // namespace learner
|
||||
} // namespace xgboost
|
||||
|
||||
Reference in New Issue
Block a user