From a730c7e67e24f267d8352771edc21c9bb295f126 Mon Sep 17 00:00:00 2001 From: david-cortes Date: Sun, 4 Feb 2024 09:22:22 +0100 Subject: [PATCH] [R] allow using seed with regular RNG (#10029) --- R-package/R/xgb.train.R | 9 +++- R-package/man/xgb.train.Rd | 5 +++ R-package/src/xgboost_custom.cc | 10 ----- R-package/tests/testthat/test_basic.R | 63 +++++++++++++++++++++++++++ doc/parameter.rst | 2 +- include/xgboost/base.h | 2 +- src/common/random.h | 2 +- 7 files changed, 78 insertions(+), 15 deletions(-) diff --git a/R-package/R/xgb.train.R b/R-package/R/xgb.train.R index f0f2332b5..44cde2e7a 100644 --- a/R-package/R/xgb.train.R +++ b/R-package/R/xgb.train.R @@ -178,6 +178,11 @@ #' Number of threads can also be manually specified via the \code{nthread} #' parameter. #' +#' While in other interfaces, the default random seed defaults to zero, in R, if a parameter `seed` +#' is not manually supplied, it will generate a random seed through R's own random number generator, +#' whose seed in turn is controllable through `set.seed`. If `seed` is passed, it will override the +#' RNG from R. +#' #' The evaluation metric is chosen automatically by XGBoost (according to the objective) #' when the \code{eval_metric} parameter is not provided. #' User may set one or several \code{eval_metric} parameters. @@ -363,8 +368,8 @@ xgb.train <- function(params = list(), data, nrounds, watchlist = list(), # Sort the callbacks into categories cb <- categorize.callbacks(callbacks) params['validate_parameters'] <- TRUE - if (!is.null(params[['seed']])) { - warning("xgb.train: `seed` is ignored in R package. 
Use `set.seed()` instead.") + if (!("seed" %in% names(params))) { + params[["seed"]] <- sample(.Machine$integer.max, size = 1) } # The tree updating process would need slightly different handling diff --git a/R-package/man/xgb.train.Rd b/R-package/man/xgb.train.Rd index 0421b9c4a..21c5fe7ee 100644 --- a/R-package/man/xgb.train.Rd +++ b/R-package/man/xgb.train.Rd @@ -241,6 +241,11 @@ Parallelization is automatically enabled if \code{OpenMP} is present. Number of threads can also be manually specified via the \code{nthread} parameter. +While in other interfaces, the default random seed defaults to zero, in R, if a parameter \code{seed} +is not manually supplied, it will generate a random seed through R's own random number generator, +whose seed in turn is controllable through \code{set.seed}. If \code{seed} is passed, it will override the +RNG from R. + The evaluation metric is chosen automatically by XGBoost (according to the objective) when the \code{eval_metric} parameter is not provided. User may set one or several \code{eval_metric} parameters. diff --git a/R-package/src/xgboost_custom.cc b/R-package/src/xgboost_custom.cc index 6aaa3696a..fdd444e5d 100644 --- a/R-package/src/xgboost_custom.cc +++ b/R-package/src/xgboost_custom.cc @@ -41,16 +41,6 @@ double LogGamma(double v) { return lgammafn(v); } #endif // !defined(XGBOOST_USE_CUDA) -// customize random engine. 
-void CustomGlobalRandomEngine::seed(CustomGlobalRandomEngine::result_type val) { - // ignore the seed -} -// use R's PRNG to replacd -CustomGlobalRandomEngine::result_type -CustomGlobalRandomEngine::operator()() { - return static_cast<result_type>( - std::floor(unif_rand() * CustomGlobalRandomEngine::max())); -} } // namespace common } // namespace xgboost diff --git a/R-package/tests/testthat/test_basic.R b/R-package/tests/testthat/test_basic.R index 03a8ddbe1..fb3162e42 100644 --- a/R-package/tests/testthat/test_basic.R +++ b/R-package/tests/testthat/test_basic.R @@ -778,3 +778,66 @@ test_that("DMatrix field are set to booster when training", { expect_equal(getinfo(model_feature_types, "feature_type"), c("q", "c", "q")) expect_equal(getinfo(model_both, "feature_type"), c("q", "c", "q")) }) + +test_that("Seed in params override PRNG from R", { + set.seed(123) + model1 <- xgb.train( + data = xgb.DMatrix( + agaricus.train$data, + label = agaricus.train$label, nthread = 1L + ), + params = list( + objective = "binary:logistic", + max_depth = 3L, + subsample = 0.1, + colsample_bytree = 0.1, + seed = 111L + ), + nrounds = 3L + ) + + set.seed(456) + model2 <- xgb.train( + data = xgb.DMatrix( + agaricus.train$data, + label = agaricus.train$label, nthread = 1L + ), + params = list( + objective = "binary:logistic", + max_depth = 3L, + subsample = 0.1, + colsample_bytree = 0.1, + seed = 111L + ), + nrounds = 3L + ) + + expect_equal( + xgb.save.raw(model1, raw_format = "json"), + xgb.save.raw(model2, raw_format = "json") + ) + + set.seed(123) + model3 <- xgb.train( + data = xgb.DMatrix( + agaricus.train$data, + label = agaricus.train$label, nthread = 1L + ), + params = list( + objective = "binary:logistic", + max_depth = 3L, + subsample = 0.1, + colsample_bytree = 0.1, + seed = 222L + ), + nrounds = 3L + ) + expect_false( + isTRUE( + all.equal( + xgb.save.raw(model1, raw_format = "json"), + xgb.save.raw(model3, raw_format = "json") + ) + ) + ) +}) diff --git a/doc/parameter.rst 
b/doc/parameter.rst index a7d8203b0..7898bb363 100644 --- a/doc/parameter.rst +++ b/doc/parameter.rst @@ -450,7 +450,7 @@ Specify the learning task and the corresponding learning objective. The objectiv * ``seed`` [default=0] - - Random number seed. This parameter is ignored in R package, use `set.seed()` instead. + - Random number seed. In the R package, if not specified, instead of defaulting to seed 'zero', will take a random seed through R's own RNG engine. * ``seed_per_iteration`` [default= ``false``] diff --git a/include/xgboost/base.h b/include/xgboost/base.h index dec306f0c..1f94c9b2f 100644 --- a/include/xgboost/base.h +++ b/include/xgboost/base.h @@ -37,7 +37,7 @@ * \brief Whether to customize global PRNG. */ #ifndef XGBOOST_CUSTOMIZE_GLOBAL_PRNG -#define XGBOOST_CUSTOMIZE_GLOBAL_PRNG XGBOOST_STRICT_R_MODE +#define XGBOOST_CUSTOMIZE_GLOBAL_PRNG 0 #endif // XGBOOST_CUSTOMIZE_GLOBAL_PRNG /*! diff --git a/src/common/random.h b/src/common/random.h index 2a94123a3..ece6fa46f 100644 --- a/src/common/random.h +++ b/src/common/random.h @@ -31,7 +31,7 @@ namespace xgboost::common { */ using RandomEngine = std::mt19937; -#if XGBOOST_CUSTOMIZE_GLOBAL_PRNG +#if defined(XGBOOST_CUSTOMIZE_GLOBAL_PRNG) && XGBOOST_CUSTOMIZE_GLOBAL_PRNG == 1 /*! * \brief An customized random engine, used to be plugged in PRNG from other systems. * The implementation of this library is not provided by xgboost core library.