[R] allow using seed with regular RNG (#10029)

This commit is contained in:
david-cortes 2024-02-04 09:22:22 +01:00 committed by GitHub
parent 662854c7d7
commit a730c7e67e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 78 additions and 15 deletions

View File

@ -178,6 +178,11 @@
#' Number of threads can also be manually specified via the \code{nthread} #' Number of threads can also be manually specified via the \code{nthread}
#' parameter. #' parameter.
#' #'
#' While in other interfaces, the default random seed defaults to zero, in R, if a parameter `seed`
#' is not manually supplied, it will generate a random seed through R's own random number generator,
#' whose seed in turn is controllable through `set.seed`. If `seed` is passed, it will override the
#' RNG from R.
#'
#' The evaluation metric is chosen automatically by XGBoost (according to the objective) #' The evaluation metric is chosen automatically by XGBoost (according to the objective)
#' when the \code{eval_metric} parameter is not provided. #' when the \code{eval_metric} parameter is not provided.
#' User may set one or several \code{eval_metric} parameters. #' User may set one or several \code{eval_metric} parameters.
@ -363,8 +368,8 @@ xgb.train <- function(params = list(), data, nrounds, watchlist = list(),
# Sort the callbacks into categories # Sort the callbacks into categories
cb <- categorize.callbacks(callbacks) cb <- categorize.callbacks(callbacks)
params['validate_parameters'] <- TRUE params['validate_parameters'] <- TRUE
if (!is.null(params[['seed']])) { if (!("seed" %in% names(params))) {
warning("xgb.train: `seed` is ignored in R package. Use `set.seed()` instead.") params[["seed"]] <- sample(.Machine$integer.max, size = 1)
} }
# The tree updating process would need slightly different handling # The tree updating process would need slightly different handling

View File

@ -241,6 +241,11 @@ Parallelization is automatically enabled if \code{OpenMP} is present.
Number of threads can also be manually specified via the \code{nthread} Number of threads can also be manually specified via the \code{nthread}
parameter. parameter.
While in other interfaces, the default random seed defaults to zero, in R, if a parameter \code{seed}
is not manually supplied, it will generate a random seed through R's own random number generator,
whose seed in turn is controllable through \code{set.seed}. If \code{seed} is passed, it will override the
RNG from R.
The evaluation metric is chosen automatically by XGBoost (according to the objective) The evaluation metric is chosen automatically by XGBoost (according to the objective)
when the \code{eval_metric} parameter is not provided. when the \code{eval_metric} parameter is not provided.
User may set one or several \code{eval_metric} parameters. User may set one or several \code{eval_metric} parameters.

View File

@ -41,16 +41,6 @@ double LogGamma(double v) {
return lgammafn(v); return lgammafn(v);
} }
#endif // !defined(XGBOOST_USE_CUDA) #endif // !defined(XGBOOST_USE_CUDA)
// customize random engine.
void CustomGlobalRandomEngine::seed(CustomGlobalRandomEngine::result_type val) {
// ignore the seed
}
// use R's PRNG to replacd
CustomGlobalRandomEngine::result_type
CustomGlobalRandomEngine::operator()() {
return static_cast<result_type>(
std::floor(unif_rand() * CustomGlobalRandomEngine::max()));
}
} // namespace common } // namespace common
} // namespace xgboost } // namespace xgboost

View File

@ -778,3 +778,66 @@ test_that("DMatrix field are set to booster when training", {
expect_equal(getinfo(model_feature_types, "feature_type"), c("q", "c", "q")) expect_equal(getinfo(model_feature_types, "feature_type"), c("q", "c", "q"))
expect_equal(getinfo(model_both, "feature_type"), c("q", "c", "q")) expect_equal(getinfo(model_both, "feature_type"), c("q", "c", "q"))
}) })
test_that("Seed in params override PRNG from R", {
set.seed(123)
model1 <- xgb.train(
data = xgb.DMatrix(
agaricus.train$data,
label = agaricus.train$label, nthread = 1L
),
params = list(
objective = "binary:logistic",
max_depth = 3L,
subsample = 0.1,
colsample_bytree = 0.1,
seed = 111L
),
nrounds = 3L
)
set.seed(456)
model2 <- xgb.train(
data = xgb.DMatrix(
agaricus.train$data,
label = agaricus.train$label, nthread = 1L
),
params = list(
objective = "binary:logistic",
max_depth = 3L,
subsample = 0.1,
colsample_bytree = 0.1,
seed = 111L
),
nrounds = 3L
)
expect_equal(
xgb.save.raw(model1, raw_format = "json"),
xgb.save.raw(model2, raw_format = "json")
)
set.seed(123)
model3 <- xgb.train(
data = xgb.DMatrix(
agaricus.train$data,
label = agaricus.train$label, nthread = 1L
),
params = list(
objective = "binary:logistic",
max_depth = 3L,
subsample = 0.1,
colsample_bytree = 0.1,
seed = 222L
),
nrounds = 3L
)
expect_false(
isTRUE(
all.equal(
xgb.save.raw(model1, raw_format = "json"),
xgb.save.raw(model3, raw_format = "json")
)
)
)
})

View File

@ -450,7 +450,7 @@ Specify the learning task and the corresponding learning objective. The objectiv
* ``seed`` [default=0] * ``seed`` [default=0]
- Random number seed. This parameter is ignored in R package, use `set.seed()` instead. - Random number seed. In the R package, if not specified, instead of defaulting to seed 'zero', will take a random seed through R's own RNG engine.
* ``seed_per_iteration`` [default= ``false``] * ``seed_per_iteration`` [default= ``false``]

View File

@ -37,7 +37,7 @@
* \brief Whether to customize global PRNG. * \brief Whether to customize global PRNG.
*/ */
#ifndef XGBOOST_CUSTOMIZE_GLOBAL_PRNG #ifndef XGBOOST_CUSTOMIZE_GLOBAL_PRNG
#define XGBOOST_CUSTOMIZE_GLOBAL_PRNG XGBOOST_STRICT_R_MODE #define XGBOOST_CUSTOMIZE_GLOBAL_PRNG 0
#endif // XGBOOST_CUSTOMIZE_GLOBAL_PRNG #endif // XGBOOST_CUSTOMIZE_GLOBAL_PRNG
/*! /*!

View File

@ -31,7 +31,7 @@ namespace xgboost::common {
*/ */
using RandomEngine = std::mt19937; using RandomEngine = std::mt19937;
#if XGBOOST_CUSTOMIZE_GLOBAL_PRNG #if defined(XGBOOST_CUSTOMIZE_GLOBAL_PRNG) && XGBOOST_CUSTOMIZE_GLOBAL_PRNG == 1
/*! /*!
* \brief An customized random engine, used to be plugged in PRNG from other systems. * \brief An customized random engine, used to be plugged in PRNG from other systems.
* The implementation of this library is not provided by xgboost core library. * The implementation of this library is not provided by xgboost core library.