Avoid resetting seed for every configuration. (#6349)

This commit is contained in:
Jiaming Yuan
2020-11-06 10:28:35 +08:00
committed by GitHub
parent f3a4253984
commit 519cee115a
4 changed files with 41 additions and 4 deletions

View File

@@ -202,6 +202,7 @@ DMLC_REGISTER_PARAMETER(LearnerTrainParam);
DMLC_REGISTER_PARAMETER(GenericParameter);
int constexpr GenericParameter::kCpuId;
int64_t constexpr GenericParameter::kDefaultSeed;
void GenericParameter::ConfigureGpuId(bool require_gpu) {
#if defined(XGBOOST_USE_CUDA)
@@ -239,6 +240,9 @@ using ThreadLocalPredictionCache =
dmlc::ThreadLocalStore<std::map<Learner const *, PredictionContainer>>;
class LearnerConfiguration : public Learner {
private:
std::mutex config_lock_;
protected:
static std::string const kEvalMetric; // NOLINT
@@ -252,7 +256,6 @@ class LearnerConfiguration : public Learner {
LearnerModelParam learner_model_param_;
LearnerTrainParam tparam_;
std::vector<std::string> metric_names_;
std::mutex config_lock_;
public:
explicit LearnerConfiguration(std::vector<std::shared_ptr<DMatrix> > cache)
@@ -283,7 +286,11 @@ class LearnerConfiguration : public Learner {
tparam_.UpdateAllowUnknown(args);
auto mparam_backup = mparam_;
mparam_.UpdateAllowUnknown(args);
auto initialized = generic_parameters_.GetInitialised();
auto old_seed = generic_parameters_.seed;
generic_parameters_.UpdateAllowUnknown(args);
generic_parameters_.CheckDeprecated();
@@ -297,7 +304,9 @@ class LearnerConfiguration : public Learner {
}
// set seed only before the model is initialized
common::GlobalRandom().seed(generic_parameters_.seed);
if (!initialized || generic_parameters_.seed != old_seed) {
common::GlobalRandom().seed(generic_parameters_.seed);
}
// must precede configure gbm since num_features is required for gbm
this->ConfigureNumFeatures();