Avoid resetting seed for every configuration. (#6349)
This commit is contained in:
parent
f3a4253984
commit
519cee115a
@ -412,6 +412,10 @@ Specify the learning task and the corresponding learning objective. The objectiv
|
|||||||
|
|
||||||
- Random number seed. This parameter is ignored in R package, use `set.seed()` instead.
|
- Random number seed. This parameter is ignored in R package, use `set.seed()` instead.
|
||||||
|
|
||||||
|
* ``seed_per_iteration`` [default=false]
|
||||||
|
|
||||||
|
- Seed PRNG determnisticly via iterator number, this option will be switched on automatically on distributed mode.
|
||||||
|
|
||||||
***********************
|
***********************
|
||||||
Command Line Parameters
|
Command Line Parameters
|
||||||
***********************
|
***********************
|
||||||
|
|||||||
@ -14,10 +14,11 @@ namespace xgboost {
|
|||||||
struct GenericParameter : public XGBoostParameter<GenericParameter> {
|
struct GenericParameter : public XGBoostParameter<GenericParameter> {
|
||||||
// Constant representing the device ID of CPU.
|
// Constant representing the device ID of CPU.
|
||||||
static int32_t constexpr kCpuId = -1;
|
static int32_t constexpr kCpuId = -1;
|
||||||
|
static int64_t constexpr kDefaultSeed = 0;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
// stored random seed
|
// stored random seed
|
||||||
int64_t seed;
|
int64_t seed { kDefaultSeed };
|
||||||
// whether seed the PRNG each iteration
|
// whether seed the PRNG each iteration
|
||||||
bool seed_per_iteration;
|
bool seed_per_iteration;
|
||||||
// number of threads to use if OpenMP is enabled
|
// number of threads to use if OpenMP is enabled
|
||||||
@ -46,7 +47,7 @@ struct GenericParameter : public XGBoostParameter<GenericParameter> {
|
|||||||
|
|
||||||
// declare parameters
|
// declare parameters
|
||||||
DMLC_DECLARE_PARAMETER(GenericParameter) {
|
DMLC_DECLARE_PARAMETER(GenericParameter) {
|
||||||
DMLC_DECLARE_FIELD(seed).set_default(0).describe(
|
DMLC_DECLARE_FIELD(seed).set_default(kDefaultSeed).describe(
|
||||||
"Random number seed during training.");
|
"Random number seed during training.");
|
||||||
DMLC_DECLARE_ALIAS(seed, random_state);
|
DMLC_DECLARE_ALIAS(seed, random_state);
|
||||||
DMLC_DECLARE_FIELD(seed_per_iteration)
|
DMLC_DECLARE_FIELD(seed_per_iteration)
|
||||||
|
|||||||
@ -202,6 +202,7 @@ DMLC_REGISTER_PARAMETER(LearnerTrainParam);
|
|||||||
DMLC_REGISTER_PARAMETER(GenericParameter);
|
DMLC_REGISTER_PARAMETER(GenericParameter);
|
||||||
|
|
||||||
int constexpr GenericParameter::kCpuId;
|
int constexpr GenericParameter::kCpuId;
|
||||||
|
int64_t constexpr GenericParameter::kDefaultSeed;
|
||||||
|
|
||||||
void GenericParameter::ConfigureGpuId(bool require_gpu) {
|
void GenericParameter::ConfigureGpuId(bool require_gpu) {
|
||||||
#if defined(XGBOOST_USE_CUDA)
|
#if defined(XGBOOST_USE_CUDA)
|
||||||
@ -239,6 +240,9 @@ using ThreadLocalPredictionCache =
|
|||||||
dmlc::ThreadLocalStore<std::map<Learner const *, PredictionContainer>>;
|
dmlc::ThreadLocalStore<std::map<Learner const *, PredictionContainer>>;
|
||||||
|
|
||||||
class LearnerConfiguration : public Learner {
|
class LearnerConfiguration : public Learner {
|
||||||
|
private:
|
||||||
|
std::mutex config_lock_;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
static std::string const kEvalMetric; // NOLINT
|
static std::string const kEvalMetric; // NOLINT
|
||||||
|
|
||||||
@ -252,7 +256,6 @@ class LearnerConfiguration : public Learner {
|
|||||||
LearnerModelParam learner_model_param_;
|
LearnerModelParam learner_model_param_;
|
||||||
LearnerTrainParam tparam_;
|
LearnerTrainParam tparam_;
|
||||||
std::vector<std::string> metric_names_;
|
std::vector<std::string> metric_names_;
|
||||||
std::mutex config_lock_;
|
|
||||||
|
|
||||||
public:
|
public:
|
||||||
explicit LearnerConfiguration(std::vector<std::shared_ptr<DMatrix> > cache)
|
explicit LearnerConfiguration(std::vector<std::shared_ptr<DMatrix> > cache)
|
||||||
@ -283,7 +286,11 @@ class LearnerConfiguration : public Learner {
|
|||||||
|
|
||||||
tparam_.UpdateAllowUnknown(args);
|
tparam_.UpdateAllowUnknown(args);
|
||||||
auto mparam_backup = mparam_;
|
auto mparam_backup = mparam_;
|
||||||
|
|
||||||
mparam_.UpdateAllowUnknown(args);
|
mparam_.UpdateAllowUnknown(args);
|
||||||
|
|
||||||
|
auto initialized = generic_parameters_.GetInitialised();
|
||||||
|
auto old_seed = generic_parameters_.seed;
|
||||||
generic_parameters_.UpdateAllowUnknown(args);
|
generic_parameters_.UpdateAllowUnknown(args);
|
||||||
generic_parameters_.CheckDeprecated();
|
generic_parameters_.CheckDeprecated();
|
||||||
|
|
||||||
@ -297,7 +304,9 @@ class LearnerConfiguration : public Learner {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// set seed only before the model is initialized
|
// set seed only before the model is initialized
|
||||||
common::GlobalRandom().seed(generic_parameters_.seed);
|
if (!initialized || generic_parameters_.seed != old_seed) {
|
||||||
|
common::GlobalRandom().seed(generic_parameters_.seed);
|
||||||
|
}
|
||||||
|
|
||||||
// must precede configure gbm since num_features is required for gbm
|
// must precede configure gbm since num_features is required for gbm
|
||||||
this->ConfigureNumFeatures();
|
this->ConfigureNumFeatures();
|
||||||
|
|||||||
@ -11,6 +11,7 @@
|
|||||||
#include <xgboost/version_config.h>
|
#include <xgboost/version_config.h>
|
||||||
#include "xgboost/json.h"
|
#include "xgboost/json.h"
|
||||||
#include "../../src/common/io.h"
|
#include "../../src/common/io.h"
|
||||||
|
#include "../../src/common/random.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
|
|
||||||
@ -333,4 +334,26 @@ TEST(Learner, Seed) {
|
|||||||
ASSERT_EQ(std::to_string(seed),
|
ASSERT_EQ(std::to_string(seed),
|
||||||
get<String>(config["learner"]["generic_param"]["seed"]));
|
get<String>(config["learner"]["generic_param"]["seed"]));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(Learner, ConstantSeed) {
|
||||||
|
auto m = RandomDataGenerator{10, 10, 0}.GenerateDMatrix(true);
|
||||||
|
std::unique_ptr<Learner> learner{Learner::Create({m})};
|
||||||
|
learner->Configure(); // seed the global random
|
||||||
|
|
||||||
|
std::uniform_real_distribution<float> dist;
|
||||||
|
auto& rng = common::GlobalRandom();
|
||||||
|
float v_0 = dist(rng);
|
||||||
|
|
||||||
|
learner->SetParam("", "");
|
||||||
|
learner->Configure(); // check configure doesn't change the seed.
|
||||||
|
float v_1 = dist(rng);
|
||||||
|
CHECK_NE(v_0, v_1);
|
||||||
|
|
||||||
|
{
|
||||||
|
rng.seed(GenericParameter::kDefaultSeed);
|
||||||
|
std::uniform_real_distribution<float> dist;
|
||||||
|
float v_2 = dist(rng);
|
||||||
|
CHECK_EQ(v_0, v_2);
|
||||||
|
}
|
||||||
|
}
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user