fix the row split recovery, add per iteration random number seed

tqchen 2014-12-21 17:31:42 -08:00
parent eff5c6baa8
commit 677475529f
6 changed files with 34 additions and 23 deletions

View File

@@ -13,7 +13,7 @@ endif
 # by default use c++11
 ifeq ($(no_cxx11),1)
 else
-	CFLAGS += -std=c++11
+	CFLAGS +=
 endif
 # specify tensor path
@@ -30,7 +30,7 @@ mpi: $(MPIBIN)
 # rules to get rabit library
 librabit:
 	if [ ! -d rabit ]; then git clone https://github.com/tqchen/rabit.git; fi
-	cd rabit;make lib/librabit.a; cd -
+	cd rabit;make lib/librabit.a lib/librabit_mock.a; cd -
 librabit_mpi:
 	if [ ! -d rabit ]; then git clone https://github.com/tqchen/rabit.git; fi
 	cd rabit;make lib/librabit_mpi.a; cd -

View File

@@ -16,7 +16,7 @@ k=$1
 python splitsvm.py ../../demo/data/agaricus.txt.train train $k
 # run xgboost mpi
-../../rabit/tracker/rabit_mpi.py $k local ../../rabit/test/keepalive.sh ../../xgboost mushroom-col.conf dsplit=col mock=0,0,1,0 mock=1,1,0,0
+../../rabit/tracker/rabit_mpi.py $k local ../../rabit/test/keepalive.sh ../../xgboost mushroom-col.conf dsplit=col mock=0,1,0,0 mock=1,1,0,0
 # the model can be directly loaded by single machine xgboost solver, as usual
 #../../xgboost mushroom-col.conf task=dump model_in=0002.model fmap=../../demo/data/featmap.txt name_dump=dump.nice.$k.txt

View File

@@ -34,6 +34,8 @@ class BoostLearner : public rabit::ISerializable {
     prob_buffer_row = 1.0f;
     distributed_mode = 0;
     pred_buffer_size = 0;
+    seed_per_iteration = 0;
+    seed = 0;
   }
   virtual ~BoostLearner(void) {
     if (obj_ != NULL) delete obj_;
@@ -102,7 +104,10 @@ class BoostLearner : public rabit::ISerializable {
       this->SetParam("updater", "grow_colmaker,refresh,prune");
     }
     if (!strcmp(name, "eval_metric")) evaluator_.AddEval(val);
-    if (!strcmp("seed", name)) random::Seed(atoi(val));
+    if (!strcmp("seed", name)) {
+      this->seed = atoi(val); random::Seed(this->seed);
+    }
+    if (!strcmp("seed_per_iter", name)) seed_per_iteration = atoi(val);
     if (!strcmp(name, "num_class")) this->SetParam("num_output_group", val);
     if (!strcmp(name, "nthread")) {
       omp_set_num_threads(atoi(val));
@@ -222,6 +227,9 @@ class BoostLearner : public rabit::ISerializable {
    * \param p_train pointer to the data matrix
    */
   inline void UpdateOneIter(int iter, const DMatrix &train) {
+    if (seed_per_iteration || rabit::IsDistributed()) {
+      random::Seed(this->seed * kRandSeedMagic + iter);
+    }
     this->PredictRaw(train, &preds_);
     obj_->GetGradient(preds_, train.info, iter, &gpair_);
     gbm_->DoBoost(train.fmat(), this->FindBufferOffset(train), train.info.info, &gpair_);
@@ -369,6 +377,12 @@ class BoostLearner : public rabit::ISerializable {
     }
   };
   // data fields
+  // stored random seed
+  int seed;
+  // whether to re-seed the PRNG at each iteration;
+  // this is important when restarting from an existing iteration,
+  // off by default, but switched on automatically in distributed mode
+  int seed_per_iteration;
   // silent during training
   int silent;
   // distributed learning mode, if any, 0:none, 1:col, 2:row
@@ -397,6 +411,8 @@ class BoostLearner : public rabit::ISerializable {
   std::vector<bst_gpair> gpair_;
  protected:
+  // magic number to transform random seed
+  const static int kRandSeedMagic = 127;
   // cache entry object that helps handle feature caching
   struct CacheEntry {
     const DMatrix *mat_;
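
The change above makes the RNG state a pure function of (seed, iteration): UpdateOneIter re-seeds from the stored seed, kRandSeedMagic, and the iteration number instead of continuing a single global random stream, so (per the new comments) a worker that restarts from an existing iteration replays exactly the random draws the original run made at that iteration. A minimal standalone sketch of the idea, with illustrative names only (kMagic and RunIteration are not xgboost code, and std::srand/std::rand stand in for xgboost's random:: wrappers):

    #include <cstdlib>
    #include <iostream>

    static const int kMagic = 127;  // plays the role of kRandSeedMagic

    void RunIteration(int seed, int iter) {
      // derive this iteration's seed from configuration + iteration number only
      std::srand(seed * kMagic + iter);
      std::cout << "iter " << iter << " first draw: " << std::rand() << "\n";
    }

    int main() {
      const int seed = 9;
      for (int iter = 0; iter < 3; ++iter) RunIteration(seed, iter);
      // a "restart" at iteration 2 reproduces the same draw printed above
      RunIteration(seed, 2);
      return 0;
    }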

View File

@@ -76,19 +76,15 @@ class BaseMaker: public IUpdater {
       unsigned n = static_cast<unsigned>(p * findex.size());
       random::Shuffle(findex);
       findex.resize(n);
-      if (n != findex.size()) {
-        // sync the findex if it is subsample
-        std::string s_cache;
-        utils::MemoryBufferStream fc(&s_cache);
-        utils::IStream &fs = fc;
-        if (rabit::GetRank() == 0) {
-          fs.Write(findex);
-          rabit::Broadcast(&s_cache, 0);
-        } else {
-          rabit::Broadcast(&s_cache, 0);
-          fs.Read(&findex);
-        }
-      }
+      // sync the findex if it is subsample
+      std::string s_cache;
+      utils::MemoryBufferStream fc(&s_cache);
+      utils::IStream &fs = fc;
+      if (rabit::GetRank() == 0) {
+        fs.Write(findex);
+      }
+      rabit::Broadcast(&s_cache, 0);
+      fs.Read(&findex);
     }
  private:
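
The guard removed above was the bug: after findex.resize(n), the condition n != findex.size() is always false, so the synchronization never ran, and the old code also reached rabit::Broadcast from two different branches. The rewritten version issues exactly one Broadcast on every rank and has every rank, rank 0 included, read findex back from the broadcast bytes, so the sampled feature index ends up byte-identical everywhere. A minimal sketch of that pattern, assuming the rabit library and a launched rabit job (the raw-byte serialization is a stand-in for utils::MemoryBufferStream):

    #include <cstring>
    #include <string>
    #include <vector>
    #include <rabit.h>

    std::vector<unsigned> SyncIndex(std::vector<unsigned> findex) {
      std::string s_cache;
      if (rabit::GetRank() == 0) {
        // only the root serializes its shuffled, subsampled index
        s_cache.assign(reinterpret_cast<const char *>(findex.data()),
                       findex.size() * sizeof(unsigned));
      }
      // one collective call, executed by every rank on every invocation
      rabit::Broadcast(&s_cache, 0);
      // all ranks, the root included, decode the broadcast bytes
      std::vector<unsigned> out(s_cache.size() / sizeof(unsigned));
      std::memcpy(out.data(), s_cache.data(), s_cache.size());
      return out;
    }

    int main(int argc, char *argv[]) {
      rabit::Init(argc, argv);
      std::vector<unsigned> findex = {3, 1, 4, 1, 5};  // rank-local order
      findex = SyncIndex(findex);
      rabit::Finalize();
      return 0;
    }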

View File

@@ -40,12 +40,11 @@ class TreeSyncher: public IUpdater {
       for (size_t i = 0; i < trees.size(); ++i) {
         trees[i]->SaveModel(fs);
       }
-      rabit::Broadcast(&s_model, 0);
-    } else {
-      rabit::Broadcast(&s_model, 0);
-      for (size_t i = 0; i < trees.size(); ++i) {
-        trees[i]->LoadModel(fs);
-      }
     }
+    fs.Seek(0);
+    rabit::Broadcast(&s_model, 0);
+    for (size_t i = 0; i < trees.size(); ++i) {
+      trees[i]->LoadModel(fs);
+    }
   }
 };
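
The restructuring above collapses the two rabit::Broadcast call sites into one that every rank executes, and the added fs.Seek(0) lets rank 0, after saving, rewind and run the same LoadModel path as every other rank, so all workers decode their trees from one shared buffer. The rewind is needed when the stream keeps a single cursor for reads and writes, which the added Seek suggests utils::MemoryBufferStream does. A toy single-cursor stream (illustrative, not the xgboost class) shows why:

    #include <cassert>
    #include <cstring>
    #include <string>

    struct MemStream {  // single cursor shared by Read and Write
      std::string *buf;
      size_t pos = 0;
      explicit MemStream(std::string *b) : buf(b) {}
      void Write(const void *p, size_t n) {
        if (buf->size() < pos + n) buf->resize(pos + n);
        std::memcpy(&(*buf)[pos], p, n);
        pos += n;
      }
      void Read(void *p, size_t n) {
        assert(pos + n <= buf->size());
        std::memcpy(p, buf->data() + pos, n);
        pos += n;
      }
      void Seek(size_t p) { pos = p; }
    };

    int main() {
      std::string s_model;
      MemStream fs(&s_model);
      int trees = 42;
      fs.Write(&trees, sizeof(trees));   // the root saves its trees
      fs.Seek(0);                        // rewind, as the patch now does
      int loaded = 0;
      fs.Read(&loaded, sizeof(loaded));  // same read path as the other ranks
      assert(loaded == 42);
      return 0;
    }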

View File

@@ -284,8 +284,8 @@ class BoostLearnTask {
 }
 int main(int argc, char *argv[]){
-  xgboost::random::Seed(0);
   xgboost::BoostLearnTask tsk;
+  tsk.SetParam("seed", "0");
   int ret = tsk.Run(argc, argv);
   rabit::Finalize();
   return ret;