add rabit checkpoint to xgb
This commit is contained in:
parent
8e16cc4617
commit
deb21351b9
@ -32,10 +32,10 @@ class GBLinear : public IGradBooster {
|
|||||||
model.param.SetParam(name, val);
|
model.param.SetParam(name, val);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
virtual void LoadModel(utils::IStream &fi) {
|
virtual void LoadModel(utils::IStream &fi, bool with_pbuffer) {
|
||||||
model.LoadModel(fi);
|
model.LoadModel(fi);
|
||||||
}
|
}
|
||||||
virtual void SaveModel(utils::IStream &fo) const {
|
virtual void SaveModel(utils::IStream &fo, bool with_pbuffer) const {
|
||||||
model.SaveModel(fo);
|
model.SaveModel(fo);
|
||||||
}
|
}
|
||||||
virtual void InitModel(void) {
|
virtual void InitModel(void) {
|
||||||
|
|||||||
@ -27,13 +27,15 @@ class IGradBooster {
|
|||||||
/*!
|
/*!
|
||||||
* \brief load model from stream
|
* \brief load model from stream
|
||||||
* \param fi input stream
|
* \param fi input stream
|
||||||
|
* \param with_pbuffer whether the incoming data contains pbuffer
|
||||||
*/
|
*/
|
||||||
virtual void LoadModel(utils::IStream &fi) = 0;
|
virtual void LoadModel(utils::IStream &fi, bool with_pbuffer) = 0;
|
||||||
/*!
|
/*!
|
||||||
* \brief save model to stream
|
* \brief save model to stream
|
||||||
* \param fo output stream
|
* \param fo output stream
|
||||||
|
* \param with_pbuffer whether to save out the pbuffer
|
||||||
*/
|
*/
|
||||||
virtual void SaveModel(utils::IStream &fo) const = 0;
|
virtual void SaveModel(utils::IStream &fo, bool with_pbuffer) const = 0;
|
||||||
/*!
|
/*!
|
||||||
* \brief initialize the model
|
* \brief initialize the model
|
||||||
*/
|
*/
|
||||||
|
|||||||
@ -39,7 +39,7 @@ class GBTree : public IGradBooster {
|
|||||||
tparam.SetParam(name, val);
|
tparam.SetParam(name, val);
|
||||||
if (trees.size() == 0) mparam.SetParam(name, val);
|
if (trees.size() == 0) mparam.SetParam(name, val);
|
||||||
}
|
}
|
||||||
virtual void LoadModel(utils::IStream &fi) {
|
virtual void LoadModel(utils::IStream &fi, bool with_pbuffer) {
|
||||||
this->Clear();
|
this->Clear();
|
||||||
utils::Check(fi.Read(&mparam, sizeof(ModelParam)) != 0,
|
utils::Check(fi.Read(&mparam, sizeof(ModelParam)) != 0,
|
||||||
"GBTree: invalid model file");
|
"GBTree: invalid model file");
|
||||||
@ -56,13 +56,19 @@ class GBTree : public IGradBooster {
|
|||||||
if (mparam.num_pbuffer != 0) {
|
if (mparam.num_pbuffer != 0) {
|
||||||
pred_buffer.resize(mparam.PredBufferSize());
|
pred_buffer.resize(mparam.PredBufferSize());
|
||||||
pred_counter.resize(mparam.PredBufferSize());
|
pred_counter.resize(mparam.PredBufferSize());
|
||||||
utils::Check(fi.Read(&pred_buffer[0], pred_buffer.size() * sizeof(float)) != 0,
|
if (with_pbuffer) {
|
||||||
"GBTree: invalid model file");
|
utils::Check(fi.Read(&pred_buffer[0], pred_buffer.size() * sizeof(float)) != 0,
|
||||||
utils::Check(fi.Read(&pred_counter[0], pred_counter.size() * sizeof(unsigned)) != 0,
|
"GBTree: invalid model file");
|
||||||
"GBTree: invalid model file");
|
utils::Check(fi.Read(&pred_counter[0], pred_counter.size() * sizeof(unsigned)) != 0,
|
||||||
|
"GBTree: invalid model file");
|
||||||
|
} else {
|
||||||
|
// reset the predict buffers if the input does not contain them
|
||||||
|
std::fill(pred_buffer.begin(), pred_buffer.end(), 0.0f);
|
||||||
|
std::fill(pred_counter.begin(), pred_counter.end(), 0);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
virtual void SaveModel(utils::IStream &fo) const {
|
virtual void SaveModel(utils::IStream &fo, bool with_pbuffer) const {
|
||||||
utils::Assert(mparam.num_trees == static_cast<int>(trees.size()), "GBTree");
|
utils::Assert(mparam.num_trees == static_cast<int>(trees.size()), "GBTree");
|
||||||
fo.Write(&mparam, sizeof(ModelParam));
|
fo.Write(&mparam, sizeof(ModelParam));
|
||||||
for (size_t i = 0; i < trees.size(); ++i) {
|
for (size_t i = 0; i < trees.size(); ++i) {
|
||||||
@ -71,7 +77,7 @@ class GBTree : public IGradBooster {
|
|||||||
if (tree_info.size() != 0) {
|
if (tree_info.size() != 0) {
|
||||||
fo.Write(&tree_info[0], sizeof(int) * tree_info.size());
|
fo.Write(&tree_info[0], sizeof(int) * tree_info.size());
|
||||||
}
|
}
|
||||||
if (mparam.num_pbuffer != 0) {
|
if (mparam.num_pbuffer != 0 && with_pbuffer) {
|
||||||
fo.Write(&pred_buffer[0], pred_buffer.size() * sizeof(float));
|
fo.Write(&pred_buffer[0], pred_buffer.size() * sizeof(float));
|
||||||
fo.Write(&pred_counter[0], pred_counter.size() * sizeof(unsigned));
|
fo.Write(&pred_counter[0], pred_counter.size() * sizeof(unsigned));
|
||||||
}
|
}
|
||||||
|
|||||||
@ -23,7 +23,7 @@ namespace learner {
|
|||||||
* \brief learner that takes do gradient boosting on specific objective functions
|
* \brief learner that takes do gradient boosting on specific objective functions
|
||||||
* and do training and prediction
|
* and do training and prediction
|
||||||
*/
|
*/
|
||||||
class BoostLearner {
|
class BoostLearner : public rabit::ISerializable {
|
||||||
public:
|
public:
|
||||||
BoostLearner(void) {
|
BoostLearner(void) {
|
||||||
obj_ = NULL;
|
obj_ = NULL;
|
||||||
@ -35,7 +35,7 @@ class BoostLearner {
|
|||||||
distributed_mode = 0;
|
distributed_mode = 0;
|
||||||
pred_buffer_size = 0;
|
pred_buffer_size = 0;
|
||||||
}
|
}
|
||||||
~BoostLearner(void) {
|
virtual ~BoostLearner(void) {
|
||||||
if (obj_ != NULL) delete obj_;
|
if (obj_ != NULL) delete obj_;
|
||||||
if (gbm_ != NULL) delete gbm_;
|
if (gbm_ != NULL) delete gbm_;
|
||||||
}
|
}
|
||||||
@ -140,9 +140,9 @@ class BoostLearner {
|
|||||||
/*!
|
/*!
|
||||||
* \brief load model from stream
|
* \brief load model from stream
|
||||||
* \param fi input stream
|
* \param fi input stream
|
||||||
* \param keep_predbuffer whether to keep predict buffer
|
* \param with_pbuffer whether to load with predict buffer
|
||||||
*/
|
*/
|
||||||
inline void LoadModel(utils::IStream &fi, bool keep_predbuffer = true) {
|
inline void LoadModel(utils::IStream &fi, bool with_pbuffer = true) {
|
||||||
utils::Check(fi.Read(&mparam, sizeof(ModelParam)) != 0,
|
utils::Check(fi.Read(&mparam, sizeof(ModelParam)) != 0,
|
||||||
"BoostLearner: wrong model format");
|
"BoostLearner: wrong model format");
|
||||||
utils::Check(fi.Read(&name_obj_), "BoostLearner: wrong model format");
|
utils::Check(fi.Read(&name_obj_), "BoostLearner: wrong model format");
|
||||||
@ -151,11 +151,23 @@ class BoostLearner {
|
|||||||
if (obj_ != NULL) delete obj_;
|
if (obj_ != NULL) delete obj_;
|
||||||
if (gbm_ != NULL) delete gbm_;
|
if (gbm_ != NULL) delete gbm_;
|
||||||
this->InitObjGBM();
|
this->InitObjGBM();
|
||||||
gbm_->LoadModel(fi);
|
gbm_->LoadModel(fi, with_pbuffer);
|
||||||
if (keep_predbuffer && distributed_mode == 2 && rabit::GetRank() != 0) {
|
if (with_pbuffer && distributed_mode == 2 && rabit::GetRank() != 0) {
|
||||||
gbm_->ResetPredBuffer(pred_buffer_size);
|
gbm_->ResetPredBuffer(pred_buffer_size);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// rabit load model from rabit checkpoint
|
||||||
|
virtual void Load(rabit::IStream &fi) {
|
||||||
|
RabitStreamAdapter fs(fi);
|
||||||
|
// for row split, we should not keep pbuffer
|
||||||
|
this->LoadModel(fs, distributed_mode != 2);
|
||||||
|
}
|
||||||
|
// rabit save model to rabit checkpoint
|
||||||
|
virtual void Save(rabit::IStream &fo) const {
|
||||||
|
RabitStreamAdapter fs(fo);
|
||||||
|
// for row split, we should not keep pbuffer
|
||||||
|
this->SaveModel(fs, distributed_mode != 2);
|
||||||
|
}
|
||||||
/*!
|
/*!
|
||||||
* \brief load model from file
|
* \brief load model from file
|
||||||
* \param fname file name
|
* \param fname file name
|
||||||
@ -165,11 +177,11 @@ class BoostLearner {
|
|||||||
this->LoadModel(fi);
|
this->LoadModel(fi);
|
||||||
fi.Close();
|
fi.Close();
|
||||||
}
|
}
|
||||||
inline void SaveModel(utils::IStream &fo) const {
|
inline void SaveModel(utils::IStream &fo, bool with_pbuffer = true) const {
|
||||||
fo.Write(&mparam, sizeof(ModelParam));
|
fo.Write(&mparam, sizeof(ModelParam));
|
||||||
fo.Write(name_obj_);
|
fo.Write(name_obj_);
|
||||||
fo.Write(name_gbm_);
|
fo.Write(name_gbm_);
|
||||||
gbm_->SaveModel(fo);
|
gbm_->SaveModel(fo, with_pbuffer);
|
||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
* \brief save model into file
|
* \brief save model into file
|
||||||
@ -394,6 +406,23 @@ class BoostLearner {
|
|||||||
// data structure field
|
// data structure field
|
||||||
/*! \brief the entries indicates that we have internal prediction cache */
|
/*! \brief the entries indicates that we have internal prediction cache */
|
||||||
std::vector<CacheEntry> cache_;
|
std::vector<CacheEntry> cache_;
|
||||||
|
|
||||||
|
private:
|
||||||
|
// adapt rabit stream to utils stream
|
||||||
|
struct RabitStreamAdapter : public utils::IStream {
|
||||||
|
// rabit stream
|
||||||
|
rabit::IStream &fs;
|
||||||
|
// constructor
|
||||||
|
RabitStreamAdapter(rabit::IStream &fs) : fs(fs) {}
|
||||||
|
// destructor
|
||||||
|
virtual ~RabitStreamAdapter(void){}
|
||||||
|
virtual size_t Read(void *ptr, size_t size) {
|
||||||
|
return fs.Read(ptr, size);
|
||||||
|
}
|
||||||
|
virtual void Write(const void *ptr, size_t size) {
|
||||||
|
fs.Write(ptr, size);
|
||||||
|
}
|
||||||
|
};
|
||||||
};
|
};
|
||||||
} // namespace learner
|
} // namespace learner
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
@ -31,14 +31,32 @@ class BoostLearnTask {
|
|||||||
this->SetParam(name, val);
|
this->SetParam(name, val);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// whether the data rank is needed
|
||||||
|
bool need_data_rank = strchr(train_path.c_str(), '%') != NULL;
|
||||||
|
// if need data rank in loading, initialize rabit engine before load data
|
||||||
|
// otherwise, initialize rabit engine after loading data
|
||||||
|
// lazy initialization of rabit engine can be helpful in speculative execution
|
||||||
|
if (need_data_rank) rabit::Init(argc, argv);
|
||||||
|
this->InitData();
|
||||||
|
if (!need_data_rank) rabit::Init(argc, argv);
|
||||||
|
if (rabit::IsDistributed()) {
|
||||||
|
std::string pname = rabit::GetProcessorName();
|
||||||
|
printf("start %s:%d\n", pname.c_str(), rabit::GetRank());
|
||||||
|
}
|
||||||
if (rabit::IsDistributed()) {
|
if (rabit::IsDistributed()) {
|
||||||
this->SetParam("data_split", "col");
|
this->SetParam("data_split", "col");
|
||||||
}
|
}
|
||||||
if (rabit::GetRank() != 0) {
|
if (rabit::GetRank() != 0) {
|
||||||
this->SetParam("silent", "2");
|
this->SetParam("silent", "2");
|
||||||
}
|
}
|
||||||
this->InitData();
|
|
||||||
this->InitLearner();
|
if (task == "train") {
|
||||||
|
// if the task is training, try to recover from a checkpoint
|
||||||
|
this->TaskTrain();
|
||||||
|
return 0;
|
||||||
|
} else {
|
||||||
|
this->InitLearner();
|
||||||
|
}
|
||||||
if (task == "dump") {
|
if (task == "dump") {
|
||||||
this->TaskDump(); return 0;
|
this->TaskDump(); return 0;
|
||||||
}
|
}
|
||||||
@ -47,8 +65,6 @@ class BoostLearnTask {
|
|||||||
}
|
}
|
||||||
if (task == "pred") {
|
if (task == "pred") {
|
||||||
this->TaskPred();
|
this->TaskPred();
|
||||||
} else {
|
|
||||||
this->TaskTrain();
|
|
||||||
}
|
}
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
@ -152,10 +168,13 @@ class BoostLearnTask {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
inline void TaskTrain(void) {
|
inline void TaskTrain(void) {
|
||||||
|
int version = rabit::LoadCheckPoint(&learner);
|
||||||
|
if (version == 0) this->InitLearner();
|
||||||
|
|
||||||
const time_t start = time(NULL);
|
const time_t start = time(NULL);
|
||||||
unsigned long elapsed = 0;
|
unsigned long elapsed = 0;
|
||||||
learner.CheckInit(data);
|
learner.CheckInit(data);
|
||||||
for (int i = 0; i < num_round; ++i) {
|
for (int i = version; i < num_round; ++i) {
|
||||||
elapsed = (unsigned long)(time(NULL) - start);
|
elapsed = (unsigned long)(time(NULL) - start);
|
||||||
if (!silent) printf("boosting round %d, %lu sec elapsed\n", i, elapsed);
|
if (!silent) printf("boosting round %d, %lu sec elapsed\n", i, elapsed);
|
||||||
learner.UpdateOneIter(i, *data);
|
learner.UpdateOneIter(i, *data);
|
||||||
@ -166,6 +185,9 @@ class BoostLearnTask {
|
|||||||
if (save_period != 0 && (i + 1) % save_period == 0) {
|
if (save_period != 0 && (i + 1) % save_period == 0) {
|
||||||
this->SaveModel(i);
|
this->SaveModel(i);
|
||||||
}
|
}
|
||||||
|
utils::Assert(rabit::VersionNumber() == i, "incorrect version number");
|
||||||
|
// checkpoint the model
|
||||||
|
rabit::CheckPoint(&learner);
|
||||||
elapsed = (unsigned long)(time(NULL) - start);
|
elapsed = (unsigned long)(time(NULL) - start);
|
||||||
}
|
}
|
||||||
// always save final round
|
// always save final round
|
||||||
@ -263,11 +285,6 @@ class BoostLearnTask {
|
|||||||
}
|
}
|
||||||
|
|
||||||
int main(int argc, char *argv[]){
|
int main(int argc, char *argv[]){
|
||||||
rabit::Init(argc, argv);
|
|
||||||
if (rabit::IsDistributed()) {
|
|
||||||
std::string pname = rabit::GetProcessorName();
|
|
||||||
printf("start %s:%d\n", pname.c_str(), rabit::GetRank());
|
|
||||||
}
|
|
||||||
xgboost::random::Seed(0);
|
xgboost::random::Seed(0);
|
||||||
xgboost::BoostLearnTask tsk;
|
xgboost::BoostLearnTask tsk;
|
||||||
int ret = tsk.Run(argc, argv);
|
int ret = tsk.Run(argc, argv);
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user