checkin continue training
This commit is contained in:
@@ -33,6 +33,7 @@ class BoostLearner {
|
||||
prob_buffer_row = 1.0f;
|
||||
part_load_col = 0;
|
||||
distributed_mode = 0;
|
||||
pred_buffer_size = 0;
|
||||
}
|
||||
~BoostLearner(void) {
|
||||
if (obj_ != NULL) delete obj_;
|
||||
@@ -76,6 +77,7 @@ class BoostLearner {
|
||||
if (!silent) {
|
||||
utils::Printf("buffer_size=%ld\n", static_cast<long>(buffer_size));
|
||||
}
|
||||
this->pred_buffer_size = buffer_size;
|
||||
}
|
||||
/*!
|
||||
* \brief set parameters from outside
|
||||
@@ -139,8 +141,9 @@ class BoostLearner {
|
||||
/*!
|
||||
* \brief load model from stream
|
||||
* \param fi input stream
|
||||
* \param keep_predbuffer whether to keep predict buffer
|
||||
*/
|
||||
inline void LoadModel(utils::IStream &fi) {
|
||||
inline void LoadModel(utils::IStream &fi, bool keep_predbuffer = true) {
|
||||
utils::Check(fi.Read(&mparam, sizeof(ModelParam)) != 0,
|
||||
"BoostLearner: wrong model format");
|
||||
utils::Check(fi.Read(&name_obj_), "BoostLearner: wrong model format");
|
||||
@@ -150,6 +153,9 @@ class BoostLearner {
|
||||
if (gbm_ != NULL) delete gbm_;
|
||||
this->InitObjGBM();
|
||||
gbm_->LoadModel(fi);
|
||||
if (keep_predbuffer && distributed_mode == 2 && sync::GetRank() != 0) {
|
||||
gbm_->ResetPredBuffer(pred_buffer_size);
|
||||
}
|
||||
}
|
||||
/*!
|
||||
* \brief load model from file
|
||||
@@ -370,12 +376,14 @@ class BoostLearner {
|
||||
int distributed_mode;
|
||||
// randomly load part of data
|
||||
int part_load_col;
|
||||
// cached size of predict buffer
|
||||
size_t pred_buffer_size;
|
||||
// maximum buffred row value
|
||||
float prob_buffer_row;
|
||||
// evaluation set
|
||||
EvalSet evaluator_;
|
||||
// model parameter
|
||||
ModelParam mparam;
|
||||
ModelParam mparam;
|
||||
// gbm model that back everything
|
||||
gbm::IGradBooster *gbm_;
|
||||
// name of gbm model used for training
|
||||
|
||||
Reference in New Issue
Block a user