add rabit checkpoint to xgb

parent 8e16cc4617
commit deb21351b9
@@ -32,10 +32,10 @@ class GBLinear : public IGradBooster {
       model.param.SetParam(name, val);
     }
   }
-  virtual void LoadModel(utils::IStream &fi) {
+  virtual void LoadModel(utils::IStream &fi, bool with_pbuffer) {
     model.LoadModel(fi);
   }
-  virtual void SaveModel(utils::IStream &fo) const {
+  virtual void SaveModel(utils::IStream &fo, bool with_pbuffer) const {
     model.SaveModel(fo);
   }
   virtual void InitModel(void) {

@@ -27,13 +27,15 @@ class IGradBooster {
   /*!
    * \brief load model from stream
    * \param fi input stream
+   * \param with_pbuffer whether the incoming data contains the pbuffer
    */
-  virtual void LoadModel(utils::IStream &fi) = 0;
+  virtual void LoadModel(utils::IStream &fi, bool with_pbuffer) = 0;
   /*!
    * \brief save model to stream
    * \param fo output stream
+   * \param with_pbuffer whether to save out the pbuffer
    */
-  virtual void SaveModel(utils::IStream &fo) const = 0;
+  virtual void SaveModel(utils::IStream &fo, bool with_pbuffer) const = 0;
   /*!
    * \brief initialize the model
    */

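The new with_pbuffer argument splits the persisted state into the model itself and the prediction buffer (pbuffer), so callers can omit the buffer when it is cheap to rebuild, as a checkpoint can. A minimal illustration of how a caller might use the flag; the helper name and the fully qualified namespaces here are assumptions for this sketch, not part of the commit:

// illustrative only: write a booster without its prediction cache,
// e.g. into a rabit checkpoint where workers rebuild predictions themselves
inline void SaveWithoutPredBuffer(xgboost::gbm::IGradBooster *gbm,
                                  xgboost::utils::IStream &fo) {
  gbm->SaveModel(fo, /* with_pbuffer = */ false);
}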
@@ -39,7 +39,7 @@ class GBTree : public IGradBooster {
     tparam.SetParam(name, val);
     if (trees.size() == 0) mparam.SetParam(name, val);
   }
-  virtual void LoadModel(utils::IStream &fi) {
+  virtual void LoadModel(utils::IStream &fi, bool with_pbuffer) {
     this->Clear();
     utils::Check(fi.Read(&mparam, sizeof(ModelParam)) != 0,
                  "GBTree: invalid model file");
@@ -56,13 +56,19 @@ class GBTree : public IGradBooster {
     if (mparam.num_pbuffer != 0) {
       pred_buffer.resize(mparam.PredBufferSize());
       pred_counter.resize(mparam.PredBufferSize());
-      utils::Check(fi.Read(&pred_buffer[0], pred_buffer.size() * sizeof(float)) != 0,
-                   "GBTree: invalid model file");
-      utils::Check(fi.Read(&pred_counter[0], pred_counter.size() * sizeof(unsigned)) != 0,
-                   "GBTree: invalid model file");
+      if (with_pbuffer) {
+        utils::Check(fi.Read(&pred_buffer[0], pred_buffer.size() * sizeof(float)) != 0,
+                     "GBTree: invalid model file");
+        utils::Check(fi.Read(&pred_counter[0], pred_counter.size() * sizeof(unsigned)) != 0,
+                     "GBTree: invalid model file");
+      } else {
+        // reset the prediction buffer when the input does not contain one
+        std::fill(pred_buffer.begin(), pred_buffer.end(), 0.0f);
+        std::fill(pred_counter.begin(), pred_counter.end(), 0);
+      }
     }
   }
-  virtual void SaveModel(utils::IStream &fo) const {
+  virtual void SaveModel(utils::IStream &fo, bool with_pbuffer) const {
     utils::Assert(mparam.num_trees == static_cast<int>(trees.size()), "GBTree");
     fo.Write(&mparam, sizeof(ModelParam));
     for (size_t i = 0; i < trees.size(); ++i) {
@@ -71,7 +77,7 @@ class GBTree : public IGradBooster {
     if (tree_info.size() != 0) {
       fo.Write(&tree_info[0], sizeof(int) * tree_info.size());
     }
-    if (mparam.num_pbuffer != 0) {
+    if (mparam.num_pbuffer != 0 && with_pbuffer) {
       fo.Write(&pred_buffer[0], pred_buffer.size() * sizeof(float));
       fo.Write(&pred_counter[0], pred_counter.size() * sizeof(unsigned));
     }

@@ -23,7 +23,7 @@ namespace learner {
  * \brief learner that performs gradient boosting on specific objective functions,
  *        doing both training and prediction
  */
-class BoostLearner {
+class BoostLearner : public rabit::ISerializable {
  public:
   BoostLearner(void) {
     obj_ = NULL;
@@ -35,7 +35,7 @@ class BoostLearner {
     distributed_mode = 0;
     pred_buffer_size = 0;
   }
-  ~BoostLearner(void) {
+  virtual ~BoostLearner(void) {
     if (obj_ != NULL) delete obj_;
     if (gbm_ != NULL) delete gbm_;
   }
@@ -140,9 +140,9 @@ class BoostLearner {
   /*!
    * \brief load model from stream
    * \param fi input stream
-   * \param keep_predbuffer whether to keep predict buffer
+   * \param with_pbuffer whether to load with predict buffer
    */
-  inline void LoadModel(utils::IStream &fi, bool keep_predbuffer = true) {
+  inline void LoadModel(utils::IStream &fi, bool with_pbuffer = true) {
     utils::Check(fi.Read(&mparam, sizeof(ModelParam)) != 0,
                  "BoostLearner: wrong model format");
     utils::Check(fi.Read(&name_obj_), "BoostLearner: wrong model format");
@@ -151,11 +151,23 @@ class BoostLearner {
     if (obj_ != NULL) delete obj_;
     if (gbm_ != NULL) delete gbm_;
     this->InitObjGBM();
-    gbm_->LoadModel(fi);
-    if (keep_predbuffer && distributed_mode == 2 && rabit::GetRank() != 0) {
+    gbm_->LoadModel(fi, with_pbuffer);
+    if (with_pbuffer && distributed_mode == 2 && rabit::GetRank() != 0) {
       gbm_->ResetPredBuffer(pred_buffer_size);
     }
   }
+  // load the model from a rabit checkpoint
+  virtual void Load(rabit::IStream &fi) {
+    RabitStreamAdapter fs(fi);
+    // for row split we should not keep the pbuffer
+    this->LoadModel(fs, distributed_mode != 2);
+  }
+  // save the model into a rabit checkpoint
+  virtual void Save(rabit::IStream &fo) const {
+    RabitStreamAdapter fs(fo);
+    // for row split we should not keep the pbuffer
+    this->SaveModel(fs, distributed_mode != 2);
+  }
   /*!
    * \brief load model from file
    * \param fname file name
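These two overrides are the reason BoostLearner now derives from rabit::ISerializable: the rabit engine checkpoints and restores any object that exposes Load/Save over a rabit stream. A paraphrased sketch of the contract being implemented, inferred from the signatures used in this diff rather than quoted from the rabit headers:

// paraphrased sketch of the interface rabit expects from a checkpointable object
namespace rabit {
class IStream {
 public:
  virtual size_t Read(void *ptr, size_t size) = 0;
  virtual void Write(const void *ptr, size_t size) = 0;
  virtual ~IStream(void) {}
};
class ISerializable {
 public:
  // restore the object's state from a checkpoint stream
  virtual void Load(IStream &fi) = 0;
  // write the object's state into a checkpoint stream
  virtual void Save(IStream &fo) const = 0;
  virtual ~ISerializable(void) {}
};
}  // namespace rabit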
@@ -165,11 +177,11 @@ class BoostLearner {
     this->LoadModel(fi);
     fi.Close();
   }
-  inline void SaveModel(utils::IStream &fo) const {
+  inline void SaveModel(utils::IStream &fo, bool with_pbuffer = true) const {
     fo.Write(&mparam, sizeof(ModelParam));
     fo.Write(name_obj_);
     fo.Write(name_gbm_);
-    gbm_->SaveModel(fo);
+    gbm_->SaveModel(fo, with_pbuffer);
   }
   /*!
    * \brief save model into file
@@ -394,6 +406,23 @@ class BoostLearner {
   // data structure field
   /*! \brief the entries indicate that we have an internal prediction cache */
   std::vector<CacheEntry> cache_;
+
+ private:
+  // adapt a rabit stream to a utils stream
+  struct RabitStreamAdapter : public utils::IStream {
+    // underlying rabit stream
+    rabit::IStream &fs;
+    // constructor
+    RabitStreamAdapter(rabit::IStream &fs) : fs(fs) {}
+    // destructor
+    virtual ~RabitStreamAdapter(void) {}
+    virtual size_t Read(void *ptr, size_t size) {
+      return fs.Read(ptr, size);
+    }
+    virtual void Write(const void *ptr, size_t size) {
+      fs.Write(ptr, size);
+    }
+  };
 };
 }  // namespace learner
 }  // namespace xgboost

@@ -31,14 +31,32 @@ class BoostLearnTask {
         this->SetParam(name, val);
       }
     }
+    // whether data loading needs the worker rank
+    bool need_data_rank = strchr(train_path.c_str(), '%') != NULL;
+    // if the rank is needed for loading, initialize the rabit engine before loading data,
+    // otherwise initialize the rabit engine after the data is loaded;
+    // lazy initialization of the rabit engine can be helpful for speculative execution
+    if (need_data_rank) rabit::Init(argc, argv);
+    this->InitData();
+    if (!need_data_rank) rabit::Init(argc, argv);
+    if (rabit::IsDistributed()) {
+      std::string pname = rabit::GetProcessorName();
+      printf("start %s:%d\n", pname.c_str(), rabit::GetRank());
+    }
     if (rabit::IsDistributed()) {
       this->SetParam("data_split", "col");
     }
     if (rabit::GetRank() != 0) {
       this->SetParam("silent", "2");
     }
-    this->InitData();
-    this->InitLearner();
+
+    if (task == "train") {
+      // if the task is training, try to recover from a previous checkpoint
+      this->TaskTrain();
+      return 0;
+    } else {
+      this->InitLearner();
+    }
     if (task == "dump") {
       this->TaskDump(); return 0;
     }
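The '%' test above signals that train_path is a per-worker template, which can only be expanded once the process knows its rabit rank; that is why rabit::Init has to run before InitData in that case, while the lazy path keeps single-node and speculative runs cheap. A hypothetical illustration of what such an expansion could look like; this helper is not part of the commit:

// hypothetical helper: expand a printf-style placeholder (e.g. "train-%d.txt")
// with the worker rank once rabit has been initialized
inline std::string ExpandRankPath(const std::string &path_template) {
  char buf[256];
  snprintf(buf, sizeof(buf), path_template.c_str(), rabit::GetRank());
  return std::string(buf);
}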
@@ -47,8 +65,6 @@ class BoostLearnTask {
     }
     if (task == "pred") {
       this->TaskPred();
-    } else {
-      this->TaskTrain();
     }
     return 0;
   }
@@ -152,10 +168,13 @@ class BoostLearnTask {
     }
   }
   inline void TaskTrain(void) {
+    int version = rabit::LoadCheckPoint(&learner);
+    if (version == 0) this->InitLearner();
+
     const time_t start = time(NULL);
     unsigned long elapsed = 0;
     learner.CheckInit(data);
-    for (int i = 0; i < num_round; ++i) {
+    for (int i = version; i < num_round; ++i) {
       elapsed = (unsigned long)(time(NULL) - start);
       if (!silent) printf("boosting round %d, %lu sec elapsed\n", i, elapsed);
       learner.UpdateOneIter(i, *data);
@@ -166,6 +185,9 @@ class BoostLearnTask {
       if (save_period != 0 && (i + 1) % save_period == 0) {
         this->SaveModel(i);
       }
+      utils::Assert(rabit::VersionNumber() == i, "incorrect version number");
+      // checkpoint the model
+      rabit::CheckPoint(&learner);
       elapsed = (unsigned long)(time(NULL) - start);
     }
     // always save final round
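Taken together, TaskTrain now follows rabit's standard fault-recovery loop: load the most recent checkpoint (a returned version of 0 means a fresh start), resume the boosting loop at that version, and checkpoint once per round so the rabit version number advances with the iteration counter. A condensed, self-contained sketch of that pattern, assuming the rabit API as used in this diff; the Model class, the weight initialization, and the include path are illustrative, not xgboost code:

#include <cstdint>
#include <vector>
#include <rabit.h>  // may be <rabit/rabit.h> depending on the rabit version

// toy checkpointable model: a flat vector of weights
class Model : public rabit::ISerializable {
 public:
  std::vector<float> w;
  virtual void Load(rabit::IStream &fi) {
    uint64_t n = 0;
    fi.Read(&n, sizeof(n));
    w.resize(n);
    if (n != 0) fi.Read(&w[0], n * sizeof(float));
  }
  virtual void Save(rabit::IStream &fo) const {
    uint64_t n = w.size();
    fo.Write(&n, sizeof(n));
    if (n != 0) fo.Write(&w[0], n * sizeof(float));
  }
};

int main(int argc, char *argv[]) {
  rabit::Init(argc, argv);
  Model learner;
  const int num_round = 100;
  // version of the last completed checkpoint; 0 on a fresh start
  int version = rabit::LoadCheckPoint(&learner);
  if (version == 0) learner.w.assign(16, 0.0f);  // fresh initialization
  for (int i = version; i < num_round; ++i) {
    // ... one boosting round (plus any allreduce calls) goes here ...
    rabit::CheckPoint(&learner);  // after this call, VersionNumber() == i + 1
  }
  rabit::Finalize();
  return 0;
}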
@@ -263,11 +285,6 @@ class BoostLearnTask {
 }

 int main(int argc, char *argv[]){
-  rabit::Init(argc, argv);
-  if (rabit::IsDistributed()) {
-    std::string pname = rabit::GetProcessorName();
-    printf("start %s:%d\n", pname.c_str(), rabit::GetRank());
-  }
   xgboost::random::Seed(0);
   xgboost::BoostLearnTask tsk;
   int ret = tsk.Run(argc, argv);