add rabit checkpoint to xgb

This commit is contained in:
tqchen
2014-12-20 01:05:40 -08:00
parent 8e16cc4617
commit deb21351b9
5 changed files with 83 additions and 29 deletions

View File

@@ -32,10 +32,10 @@ class GBLinear : public IGradBooster {
model.param.SetParam(name, val);
}
}
virtual void LoadModel(utils::IStream &fi) {
virtual void LoadModel(utils::IStream &fi, bool with_pbuffer) {
model.LoadModel(fi);
}
virtual void SaveModel(utils::IStream &fo) const {
virtual void SaveModel(utils::IStream &fo, bool with_pbuffer) const {
model.SaveModel(fo);
}
virtual void InitModel(void) {

View File

@@ -27,13 +27,15 @@ class IGradBooster {
/*!
* \brief load model from stream
* \param fi input stream
* \param with_pbuffer whether the incoming data contains pbuffer
*/
virtual void LoadModel(utils::IStream &fi) = 0;
virtual void LoadModel(utils::IStream &fi, bool with_pbuffer) = 0;
/*!
* \brief save model to stream
* \param fo output stream
* \param with_pbuffer whether save out pbuffer
*/
virtual void SaveModel(utils::IStream &fo) const = 0;
virtual void SaveModel(utils::IStream &fo, bool with_pbuffer) const = 0;
/*!
* \brief initialize the model
*/

View File

@@ -39,7 +39,7 @@ class GBTree : public IGradBooster {
tparam.SetParam(name, val);
if (trees.size() == 0) mparam.SetParam(name, val);
}
virtual void LoadModel(utils::IStream &fi) {
virtual void LoadModel(utils::IStream &fi, bool with_pbuffer) {
this->Clear();
utils::Check(fi.Read(&mparam, sizeof(ModelParam)) != 0,
"GBTree: invalid model file");
@@ -56,13 +56,19 @@ class GBTree : public IGradBooster {
if (mparam.num_pbuffer != 0) {
pred_buffer.resize(mparam.PredBufferSize());
pred_counter.resize(mparam.PredBufferSize());
utils::Check(fi.Read(&pred_buffer[0], pred_buffer.size() * sizeof(float)) != 0,
"GBTree: invalid model file");
utils::Check(fi.Read(&pred_counter[0], pred_counter.size() * sizeof(unsigned)) != 0,
"GBTree: invalid model file");
if (with_pbuffer) {
utils::Check(fi.Read(&pred_buffer[0], pred_buffer.size() * sizeof(float)) != 0,
"GBTree: invalid model file");
utils::Check(fi.Read(&pred_counter[0], pred_counter.size() * sizeof(unsigned)) != 0,
"GBTree: invalid model file");
} else {
// reset predict buffer if the input do not have them
std::fill(pred_buffer.begin(), pred_buffer.end(), 0.0f);
std::fill(pred_counter.begin(), pred_counter.end(), 0);
}
}
}
virtual void SaveModel(utils::IStream &fo) const {
virtual void SaveModel(utils::IStream &fo, bool with_pbuffer) const {
utils::Assert(mparam.num_trees == static_cast<int>(trees.size()), "GBTree");
fo.Write(&mparam, sizeof(ModelParam));
for (size_t i = 0; i < trees.size(); ++i) {
@@ -71,7 +77,7 @@ class GBTree : public IGradBooster {
if (tree_info.size() != 0) {
fo.Write(&tree_info[0], sizeof(int) * tree_info.size());
}
if (mparam.num_pbuffer != 0) {
if (mparam.num_pbuffer != 0 && with_pbuffer) {
fo.Write(&pred_buffer[0], pred_buffer.size() * sizeof(float));
fo.Write(&pred_counter[0], pred_counter.size() * sizeof(unsigned));
}