checkin continue training

This commit is contained in:
tqchen 2014-11-19 20:06:08 -08:00
parent 26e5eae6f2
commit 970dd58dc2
5 changed files with 27 additions and 4 deletions

View File

@ -17,5 +17,8 @@ cd -
python splitrows.py ../../demo/regression/machine.txt.train train-machine $k
# run xgboost mpi
mpirun -n $k ../../xgboost-mpi machine-row.conf dsplit=row
mpirun -n $k ../../xgboost-mpi machine-row.conf dsplit=row num_round=3
# run xgboost-mpi save model 0001, continue to run from existing model
mpirun -n $k ../../xgboost-mpi machine-row.conf dsplit=row num_round=1
mpirun -n $k ../../xgboost-mpi machine-row.conf dsplit=row num_round=2 model_in=0001.model

View File

@ -38,6 +38,12 @@ class IGradBooster {
* \brief initialize the model
*/
virtual void InitModel(void) = 0;
/*!
* \brief reset the predict buffer
* this will invalidate all the previous cached results
* and recalculate from scratch
*/
virtual void ResetPredBuffer(size_t num_pbuffer) {}
/*!
* \brief peform update to the model(boosting)
* \param p_fmat feature matrix that provide access to features

View File

@ -84,6 +84,12 @@ class GBTree : public IGradBooster {
utils::Assert(mparam.num_trees == 0, "GBTree: model already initialized");
utils::Assert(trees.size() == 0, "GBTree: model already initialized");
}
virtual void ResetPredBuffer(size_t num_pbuffer) {
mparam.num_pbuffer = static_cast<int64_t>(num_pbuffer);
pred_buffer.clear(); pred_counter.clear();
pred_buffer.resize(mparam.PredBufferSize(), 0.0f);
pred_counter.resize(mparam.PredBufferSize(), 0);
}
virtual void DoBoost(IFMatrix *p_fmat,
int64_t buffer_offset,
const BoosterInfo &info,

View File

@ -33,6 +33,7 @@ class BoostLearner {
prob_buffer_row = 1.0f;
part_load_col = 0;
distributed_mode = 0;
pred_buffer_size = 0;
}
~BoostLearner(void) {
if (obj_ != NULL) delete obj_;
@ -76,6 +77,7 @@ class BoostLearner {
if (!silent) {
utils::Printf("buffer_size=%ld\n", static_cast<long>(buffer_size));
}
this->pred_buffer_size = buffer_size;
}
/*!
* \brief set parameters from outside
@ -139,8 +141,9 @@ class BoostLearner {
/*!
* \brief load model from stream
* \param fi input stream
* \param keep_predbuffer whether to keep predict buffer
*/
inline void LoadModel(utils::IStream &fi) {
inline void LoadModel(utils::IStream &fi, bool keep_predbuffer = true) {
utils::Check(fi.Read(&mparam, sizeof(ModelParam)) != 0,
"BoostLearner: wrong model format");
utils::Check(fi.Read(&name_obj_), "BoostLearner: wrong model format");
@ -150,6 +153,9 @@ class BoostLearner {
if (gbm_ != NULL) delete gbm_;
this->InitObjGBM();
gbm_->LoadModel(fi);
if (keep_predbuffer && distributed_mode == 2 && sync::GetRank() != 0) {
gbm_->ResetPredBuffer(pred_buffer_size);
}
}
/*!
* \brief load model from file
@ -370,6 +376,8 @@ class BoostLearner {
int distributed_mode;
// randomly load part of data
int part_load_col;
// cached size of predict buffer
size_t pred_buffer_size;
// maximum buffred row value
float prob_buffer_row;
// evaluation set

View File

@ -142,7 +142,7 @@ class BoostLearnTask {
}
}
inline void InitLearner(void) {
if (model_in != "NULL"){
if (model_in != "NULL") {
utils::FileStream fi(utils::FopenCheck(model_in.c_str(), "rb"));
learner.LoadModel(fi);
fi.Close();