fix the row split recovery, add per iteration random number seed
This commit is contained in:
@@ -34,6 +34,8 @@ class BoostLearner : public rabit::ISerializable {
|
||||
prob_buffer_row = 1.0f;
|
||||
distributed_mode = 0;
|
||||
pred_buffer_size = 0;
|
||||
seed_per_iteration = 0;
|
||||
seed = 0;
|
||||
}
|
||||
virtual ~BoostLearner(void) {
|
||||
if (obj_ != NULL) delete obj_;
|
||||
@@ -102,7 +104,10 @@ class BoostLearner : public rabit::ISerializable {
|
||||
this->SetParam("updater", "grow_colmaker,refresh,prune");
|
||||
}
|
||||
if (!strcmp(name, "eval_metric")) evaluator_.AddEval(val);
|
||||
if (!strcmp("seed", name)) random::Seed(atoi(val));
|
||||
if (!strcmp("seed", name)) {
|
||||
this->seed = seed; random::Seed(atoi(val));
|
||||
}
|
||||
if (!strcmp("seed_per_iter", name)) seed_per_iteration = atoi(val);
|
||||
if (!strcmp(name, "num_class")) this->SetParam("num_output_group", val);
|
||||
if (!strcmp(name, "nthread")) {
|
||||
omp_set_num_threads(atoi(val));
|
||||
@@ -222,6 +227,9 @@ class BoostLearner : public rabit::ISerializable {
|
||||
* \param p_train pointer to the data matrix
|
||||
*/
|
||||
inline void UpdateOneIter(int iter, const DMatrix &train) {
|
||||
if (seed_per_iteration || rabit::IsDistributed()) {
|
||||
random::Seed(this->seed * kRandSeedMagic);
|
||||
}
|
||||
this->PredictRaw(train, &preds_);
|
||||
obj_->GetGradient(preds_, train.info, iter, &gpair_);
|
||||
gbm_->DoBoost(train.fmat(), this->FindBufferOffset(train), train.info.info, &gpair_);
|
||||
@@ -369,6 +377,12 @@ class BoostLearner : public rabit::ISerializable {
|
||||
}
|
||||
};
|
||||
// data fields
|
||||
// stored random seed
|
||||
int seed;
|
||||
// whether seed the PRNG each iteration
|
||||
// this is important for restart from existing iterations
|
||||
// default set to no, but will auto switch on in distributed mode
|
||||
int seed_per_iteration;
|
||||
// silent during training
|
||||
int silent;
|
||||
// distributed learning mode, if any, 0:none, 1:col, 2:row
|
||||
@@ -397,6 +411,8 @@ class BoostLearner : public rabit::ISerializable {
|
||||
std::vector<bst_gpair> gpair_;
|
||||
|
||||
protected:
|
||||
// magic number to transform random seed
|
||||
const static int kRandSeedMagic = 127;
|
||||
// cache entry object that helps handle feature caching
|
||||
struct CacheEntry {
|
||||
const DMatrix *mat_;
|
||||
|
||||
Reference in New Issue
Block a user