add rabit checkpoint to xgb
This commit is contained in:
@@ -32,10 +32,10 @@ class GBLinear : public IGradBooster {
|
||||
model.param.SetParam(name, val);
|
||||
}
|
||||
}
|
||||
virtual void LoadModel(utils::IStream &fi) {
|
||||
virtual void LoadModel(utils::IStream &fi, bool with_pbuffer) {
|
||||
model.LoadModel(fi);
|
||||
}
|
||||
virtual void SaveModel(utils::IStream &fo) const {
|
||||
virtual void SaveModel(utils::IStream &fo, bool with_pbuffer) const {
|
||||
model.SaveModel(fo);
|
||||
}
|
||||
virtual void InitModel(void) {
|
||||
|
||||
@@ -27,13 +27,15 @@ class IGradBooster {
|
||||
/*!
|
||||
* \brief load model from stream
|
||||
* \param fi input stream
|
||||
* \param with_pbuffer whether the incoming data contains pbuffer
|
||||
*/
|
||||
virtual void LoadModel(utils::IStream &fi) = 0;
|
||||
virtual void LoadModel(utils::IStream &fi, bool with_pbuffer) = 0;
|
||||
/*!
|
||||
* \brief save model to stream
|
||||
* \param fo output stream
|
||||
* \param with_pbuffer whether save out pbuffer
|
||||
*/
|
||||
virtual void SaveModel(utils::IStream &fo) const = 0;
|
||||
virtual void SaveModel(utils::IStream &fo, bool with_pbuffer) const = 0;
|
||||
/*!
|
||||
* \brief initialize the model
|
||||
*/
|
||||
|
||||
@@ -39,7 +39,7 @@ class GBTree : public IGradBooster {
|
||||
tparam.SetParam(name, val);
|
||||
if (trees.size() == 0) mparam.SetParam(name, val);
|
||||
}
|
||||
virtual void LoadModel(utils::IStream &fi) {
|
||||
virtual void LoadModel(utils::IStream &fi, bool with_pbuffer) {
|
||||
this->Clear();
|
||||
utils::Check(fi.Read(&mparam, sizeof(ModelParam)) != 0,
|
||||
"GBTree: invalid model file");
|
||||
@@ -56,13 +56,19 @@ class GBTree : public IGradBooster {
|
||||
if (mparam.num_pbuffer != 0) {
|
||||
pred_buffer.resize(mparam.PredBufferSize());
|
||||
pred_counter.resize(mparam.PredBufferSize());
|
||||
utils::Check(fi.Read(&pred_buffer[0], pred_buffer.size() * sizeof(float)) != 0,
|
||||
"GBTree: invalid model file");
|
||||
utils::Check(fi.Read(&pred_counter[0], pred_counter.size() * sizeof(unsigned)) != 0,
|
||||
"GBTree: invalid model file");
|
||||
if (with_pbuffer) {
|
||||
utils::Check(fi.Read(&pred_buffer[0], pred_buffer.size() * sizeof(float)) != 0,
|
||||
"GBTree: invalid model file");
|
||||
utils::Check(fi.Read(&pred_counter[0], pred_counter.size() * sizeof(unsigned)) != 0,
|
||||
"GBTree: invalid model file");
|
||||
} else {
|
||||
// reset predict buffer if the input do not have them
|
||||
std::fill(pred_buffer.begin(), pred_buffer.end(), 0.0f);
|
||||
std::fill(pred_counter.begin(), pred_counter.end(), 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
virtual void SaveModel(utils::IStream &fo) const {
|
||||
virtual void SaveModel(utils::IStream &fo, bool with_pbuffer) const {
|
||||
utils::Assert(mparam.num_trees == static_cast<int>(trees.size()), "GBTree");
|
||||
fo.Write(&mparam, sizeof(ModelParam));
|
||||
for (size_t i = 0; i < trees.size(); ++i) {
|
||||
@@ -71,7 +77,7 @@ class GBTree : public IGradBooster {
|
||||
if (tree_info.size() != 0) {
|
||||
fo.Write(&tree_info[0], sizeof(int) * tree_info.size());
|
||||
}
|
||||
if (mparam.num_pbuffer != 0) {
|
||||
if (mparam.num_pbuffer != 0 && with_pbuffer) {
|
||||
fo.Write(&pred_buffer[0], pred_buffer.size() * sizeof(float));
|
||||
fo.Write(&pred_counter[0], pred_counter.size() * sizeof(unsigned));
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user