[R] allow using seed with regular RNG (#10029)
This commit is contained in:
parent
662854c7d7
commit
a730c7e67e
@ -178,6 +178,11 @@
|
||||
#' Number of threads can also be manually specified via the \code{nthread}
|
||||
#' parameter.
|
||||
#'
|
||||
#' While in other interfaces, the default random seed defaults to zero, in R, if a parameter `seed`
|
||||
#' is not manually supplied, it will generate a random seed through R's own random number generator,
|
||||
#' whose seed in turn is controllable through `set.seed`. If `seed` is passed, it will override the
|
||||
#' RNG from R.
|
||||
#'
|
||||
#' The evaluation metric is chosen automatically by XGBoost (according to the objective)
|
||||
#' when the \code{eval_metric} parameter is not provided.
|
||||
#' User may set one or several \code{eval_metric} parameters.
|
||||
@ -363,8 +368,8 @@ xgb.train <- function(params = list(), data, nrounds, watchlist = list(),
|
||||
# Sort the callbacks into categories
|
||||
cb <- categorize.callbacks(callbacks)
|
||||
params['validate_parameters'] <- TRUE
|
||||
if (!is.null(params[['seed']])) {
|
||||
warning("xgb.train: `seed` is ignored in R package. Use `set.seed()` instead.")
|
||||
if (!("seed" %in% names(params))) {
|
||||
params[["seed"]] <- sample(.Machine$integer.max, size = 1)
|
||||
}
|
||||
|
||||
# The tree updating process would need slightly different handling
|
||||
|
||||
@ -241,6 +241,11 @@ Parallelization is automatically enabled if \code{OpenMP} is present.
|
||||
Number of threads can also be manually specified via the \code{nthread}
|
||||
parameter.
|
||||
|
||||
While in other interfaces, the default random seed defaults to zero, in R, if a parameter \code{seed}
|
||||
is not manually supplied, it will generate a random seed through R's own random number generator,
|
||||
whose seed in turn is controllable through \code{set.seed}. If \code{seed} is passed, it will override the
|
||||
RNG from R.
|
||||
|
||||
The evaluation metric is chosen automatically by XGBoost (according to the objective)
|
||||
when the \code{eval_metric} parameter is not provided.
|
||||
User may set one or several \code{eval_metric} parameters.
|
||||
|
||||
@ -41,16 +41,6 @@ double LogGamma(double v) {
|
||||
return lgammafn(v);
|
||||
}
|
||||
#endif // !defined(XGBOOST_USE_CUDA)
|
||||
// customize random engine.
|
||||
void CustomGlobalRandomEngine::seed(CustomGlobalRandomEngine::result_type val) {
|
||||
// ignore the seed
|
||||
}
|
||||
|
||||
// use R's PRNG to replacd
|
||||
CustomGlobalRandomEngine::result_type
|
||||
CustomGlobalRandomEngine::operator()() {
|
||||
return static_cast<result_type>(
|
||||
std::floor(unif_rand() * CustomGlobalRandomEngine::max()));
|
||||
}
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
|
||||
@ -778,3 +778,66 @@ test_that("DMatrix field are set to booster when training", {
|
||||
expect_equal(getinfo(model_feature_types, "feature_type"), c("q", "c", "q"))
|
||||
expect_equal(getinfo(model_both, "feature_type"), c("q", "c", "q"))
|
||||
})
|
||||
|
||||
test_that("Seed in params override PRNG from R", {
|
||||
set.seed(123)
|
||||
model1 <- xgb.train(
|
||||
data = xgb.DMatrix(
|
||||
agaricus.train$data,
|
||||
label = agaricus.train$label, nthread = 1L
|
||||
),
|
||||
params = list(
|
||||
objective = "binary:logistic",
|
||||
max_depth = 3L,
|
||||
subsample = 0.1,
|
||||
colsample_bytree = 0.1,
|
||||
seed = 111L
|
||||
),
|
||||
nrounds = 3L
|
||||
)
|
||||
|
||||
set.seed(456)
|
||||
model2 <- xgb.train(
|
||||
data = xgb.DMatrix(
|
||||
agaricus.train$data,
|
||||
label = agaricus.train$label, nthread = 1L
|
||||
),
|
||||
params = list(
|
||||
objective = "binary:logistic",
|
||||
max_depth = 3L,
|
||||
subsample = 0.1,
|
||||
colsample_bytree = 0.1,
|
||||
seed = 111L
|
||||
),
|
||||
nrounds = 3L
|
||||
)
|
||||
|
||||
expect_equal(
|
||||
xgb.save.raw(model1, raw_format = "json"),
|
||||
xgb.save.raw(model2, raw_format = "json")
|
||||
)
|
||||
|
||||
set.seed(123)
|
||||
model3 <- xgb.train(
|
||||
data = xgb.DMatrix(
|
||||
agaricus.train$data,
|
||||
label = agaricus.train$label, nthread = 1L
|
||||
),
|
||||
params = list(
|
||||
objective = "binary:logistic",
|
||||
max_depth = 3L,
|
||||
subsample = 0.1,
|
||||
colsample_bytree = 0.1,
|
||||
seed = 222L
|
||||
),
|
||||
nrounds = 3L
|
||||
)
|
||||
expect_false(
|
||||
isTRUE(
|
||||
all.equal(
|
||||
xgb.save.raw(model1, raw_format = "json"),
|
||||
xgb.save.raw(model3, raw_format = "json")
|
||||
)
|
||||
)
|
||||
)
|
||||
})
|
||||
|
||||
@ -450,7 +450,7 @@ Specify the learning task and the corresponding learning objective. The objectiv
|
||||
|
||||
* ``seed`` [default=0]
|
||||
|
||||
- Random number seed. This parameter is ignored in R package, use `set.seed()` instead.
|
||||
- Random number seed. In the R package, if not specified, instead of defaulting to seed 'zero', will take a random seed through R's own RNG engine.
|
||||
|
||||
* ``seed_per_iteration`` [default= ``false``]
|
||||
|
||||
|
||||
@ -37,7 +37,7 @@
|
||||
* \brief Whether to customize global PRNG.
|
||||
*/
|
||||
#ifndef XGBOOST_CUSTOMIZE_GLOBAL_PRNG
|
||||
#define XGBOOST_CUSTOMIZE_GLOBAL_PRNG XGBOOST_STRICT_R_MODE
|
||||
#define XGBOOST_CUSTOMIZE_GLOBAL_PRNG 0
|
||||
#endif // XGBOOST_CUSTOMIZE_GLOBAL_PRNG
|
||||
|
||||
/*!
|
||||
|
||||
@ -31,7 +31,7 @@ namespace xgboost::common {
|
||||
*/
|
||||
using RandomEngine = std::mt19937;
|
||||
|
||||
#if XGBOOST_CUSTOMIZE_GLOBAL_PRNG
|
||||
#if defined(XGBOOST_CUSTOMIZE_GLOBAL_PRNG) && XGBOOST_CUSTOMIZE_GLOBAL_PRNG == 1
|
||||
/*!
|
||||
* \brief An customized random engine, used to be plugged in PRNG from other systems.
|
||||
* The implementation of this library is not provided by xgboost core library.
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user