checkin continue training

This commit is contained in:
tqchen
2014-11-19 20:06:08 -08:00
parent 26e5eae6f2
commit 970dd58dc2
5 changed files with 27 additions and 4 deletions

View File

@@ -33,6 +33,7 @@ class BoostLearner {
prob_buffer_row = 1.0f;
part_load_col = 0;
distributed_mode = 0;
pred_buffer_size = 0;
}
~BoostLearner(void) {
if (obj_ != NULL) delete obj_;
@@ -76,6 +77,7 @@ class BoostLearner {
if (!silent) {
utils::Printf("buffer_size=%ld\n", static_cast<long>(buffer_size));
}
this->pred_buffer_size = buffer_size;
}
/*!
* \brief set parameters from outside
@@ -139,8 +141,9 @@ class BoostLearner {
/*!
* \brief load model from stream
* \param fi input stream
* \param keep_predbuffer whether to keep predict buffer
*/
inline void LoadModel(utils::IStream &fi) {
inline void LoadModel(utils::IStream &fi, bool keep_predbuffer = true) {
utils::Check(fi.Read(&mparam, sizeof(ModelParam)) != 0,
"BoostLearner: wrong model format");
utils::Check(fi.Read(&name_obj_), "BoostLearner: wrong model format");
@@ -150,6 +153,9 @@ class BoostLearner {
if (gbm_ != NULL) delete gbm_;
this->InitObjGBM();
gbm_->LoadModel(fi);
if (keep_predbuffer && distributed_mode == 2 && sync::GetRank() != 0) {
gbm_->ResetPredBuffer(pred_buffer_size);
}
}
/*!
* \brief load model from file
@@ -370,12 +376,14 @@ class BoostLearner {
int distributed_mode;
// randomly load part of data
int part_load_col;
// cached size of predict buffer
size_t pred_buffer_size;
// maximum buffred row value
float prob_buffer_row;
// evaluation set
EvalSet evaluator_;
// model parameter
ModelParam mparam;
ModelParam mparam;
// gbm model that back everything
gbm::IGradBooster *gbm_;
// name of gbm model used for training