[R] allow using seed with regular RNG (#10029)
This commit is contained in:
parent
662854c7d7
commit
a730c7e67e
@ -178,6 +178,11 @@
|
|||||||
#' Number of threads can also be manually specified via the \code{nthread}
|
#' Number of threads can also be manually specified via the \code{nthread}
|
||||||
#' parameter.
|
#' parameter.
|
||||||
#'
|
#'
|
||||||
|
#' While in other interfaces the random seed defaults to zero, in R, if a parameter `seed`
|
||||||
|
#' is not manually supplied, a random seed will be generated through R's own random number generator,
|
||||||
|
#' whose seed in turn is controllable through `set.seed`. If `seed` is passed, it will override the
|
||||||
|
#' RNG from R.
|
||||||
|
#'
|
||||||
#' The evaluation metric is chosen automatically by XGBoost (according to the objective)
|
#' The evaluation metric is chosen automatically by XGBoost (according to the objective)
|
||||||
#' when the \code{eval_metric} parameter is not provided.
|
#' when the \code{eval_metric} parameter is not provided.
|
||||||
#' User may set one or several \code{eval_metric} parameters.
|
#' User may set one or several \code{eval_metric} parameters.
|
||||||
@ -363,8 +368,8 @@ xgb.train <- function(params = list(), data, nrounds, watchlist = list(),
|
|||||||
# Sort the callbacks into categories
|
# Sort the callbacks into categories
|
||||||
cb <- categorize.callbacks(callbacks)
|
cb <- categorize.callbacks(callbacks)
|
||||||
params['validate_parameters'] <- TRUE
|
params['validate_parameters'] <- TRUE
|
||||||
if (!is.null(params[['seed']])) {
|
if (!("seed" %in% names(params))) {
|
||||||
warning("xgb.train: `seed` is ignored in R package. Use `set.seed()` instead.")
|
params[["seed"]] <- sample(.Machine$integer.max, size = 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
# The tree updating process would need slightly different handling
|
# The tree updating process would need slightly different handling
|
||||||
|
|||||||
@ -241,6 +241,11 @@ Parallelization is automatically enabled if \code{OpenMP} is present.
|
|||||||
Number of threads can also be manually specified via the \code{nthread}
|
Number of threads can also be manually specified via the \code{nthread}
|
||||||
parameter.
|
parameter.
|
||||||
|
|
||||||
|
While in other interfaces the random seed defaults to zero, in R, if a parameter \code{seed}
|
||||||
|
is not manually supplied, a random seed will be generated through R's own random number generator,
|
||||||
|
whose seed in turn is controllable through \code{set.seed}. If \code{seed} is passed, it will override the
|
||||||
|
RNG from R.
|
||||||
|
|
||||||
The evaluation metric is chosen automatically by XGBoost (according to the objective)
|
The evaluation metric is chosen automatically by XGBoost (according to the objective)
|
||||||
when the \code{eval_metric} parameter is not provided.
|
when the \code{eval_metric} parameter is not provided.
|
||||||
User may set one or several \code{eval_metric} parameters.
|
User may set one or several \code{eval_metric} parameters.
|
||||||
|
|||||||
@ -41,16 +41,6 @@ double LogGamma(double v) {
|
|||||||
return lgammafn(v);
|
return lgammafn(v);
|
||||||
}
|
}
|
||||||
#endif // !defined(XGBOOST_USE_CUDA)
|
#endif // !defined(XGBOOST_USE_CUDA)
|
||||||
// customize random engine.
|
|
||||||
void CustomGlobalRandomEngine::seed(CustomGlobalRandomEngine::result_type val) {
|
|
||||||
// ignore the seed
|
|
||||||
}
|
|
||||||
|
|
||||||
// use R's PRNG as the replacement
|
|
||||||
CustomGlobalRandomEngine::result_type
|
|
||||||
CustomGlobalRandomEngine::operator()() {
|
|
||||||
return static_cast<result_type>(
|
|
||||||
std::floor(unif_rand() * CustomGlobalRandomEngine::max()));
|
|
||||||
}
|
|
||||||
} // namespace common
|
} // namespace common
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
@ -778,3 +778,66 @@ test_that("DMatrix field are set to booster when training", {
|
|||||||
expect_equal(getinfo(model_feature_types, "feature_type"), c("q", "c", "q"))
|
expect_equal(getinfo(model_feature_types, "feature_type"), c("q", "c", "q"))
|
||||||
expect_equal(getinfo(model_both, "feature_type"), c("q", "c", "q"))
|
expect_equal(getinfo(model_both, "feature_type"), c("q", "c", "q"))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
test_that("Seed in params override PRNG from R", {
|
||||||
|
set.seed(123)
|
||||||
|
model1 <- xgb.train(
|
||||||
|
data = xgb.DMatrix(
|
||||||
|
agaricus.train$data,
|
||||||
|
label = agaricus.train$label, nthread = 1L
|
||||||
|
),
|
||||||
|
params = list(
|
||||||
|
objective = "binary:logistic",
|
||||||
|
max_depth = 3L,
|
||||||
|
subsample = 0.1,
|
||||||
|
colsample_bytree = 0.1,
|
||||||
|
seed = 111L
|
||||||
|
),
|
||||||
|
nrounds = 3L
|
||||||
|
)
|
||||||
|
|
||||||
|
set.seed(456)
|
||||||
|
model2 <- xgb.train(
|
||||||
|
data = xgb.DMatrix(
|
||||||
|
agaricus.train$data,
|
||||||
|
label = agaricus.train$label, nthread = 1L
|
||||||
|
),
|
||||||
|
params = list(
|
||||||
|
objective = "binary:logistic",
|
||||||
|
max_depth = 3L,
|
||||||
|
subsample = 0.1,
|
||||||
|
colsample_bytree = 0.1,
|
||||||
|
seed = 111L
|
||||||
|
),
|
||||||
|
nrounds = 3L
|
||||||
|
)
|
||||||
|
|
||||||
|
expect_equal(
|
||||||
|
xgb.save.raw(model1, raw_format = "json"),
|
||||||
|
xgb.save.raw(model2, raw_format = "json")
|
||||||
|
)
|
||||||
|
|
||||||
|
set.seed(123)
|
||||||
|
model3 <- xgb.train(
|
||||||
|
data = xgb.DMatrix(
|
||||||
|
agaricus.train$data,
|
||||||
|
label = agaricus.train$label, nthread = 1L
|
||||||
|
),
|
||||||
|
params = list(
|
||||||
|
objective = "binary:logistic",
|
||||||
|
max_depth = 3L,
|
||||||
|
subsample = 0.1,
|
||||||
|
colsample_bytree = 0.1,
|
||||||
|
seed = 222L
|
||||||
|
),
|
||||||
|
nrounds = 3L
|
||||||
|
)
|
||||||
|
expect_false(
|
||||||
|
isTRUE(
|
||||||
|
all.equal(
|
||||||
|
xgb.save.raw(model1, raw_format = "json"),
|
||||||
|
xgb.save.raw(model3, raw_format = "json")
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|||||||
@ -450,7 +450,7 @@ Specify the learning task and the corresponding learning objective. The objectiv
|
|||||||
|
|
||||||
* ``seed`` [default=0]
|
* ``seed`` [default=0]
|
||||||
|
|
||||||
- Random number seed. This parameter is ignored in R package, use `set.seed()` instead.
|
- Random number seed. In the R package, if not specified, instead of defaulting to zero, a random seed is drawn through R's own RNG engine.
|
||||||
|
|
||||||
* ``seed_per_iteration`` [default= ``false``]
|
* ``seed_per_iteration`` [default= ``false``]
|
||||||
|
|
||||||
|
|||||||
@ -37,7 +37,7 @@
|
|||||||
* \brief Whether to customize global PRNG.
|
* \brief Whether to customize global PRNG.
|
||||||
*/
|
*/
|
||||||
#ifndef XGBOOST_CUSTOMIZE_GLOBAL_PRNG
|
#ifndef XGBOOST_CUSTOMIZE_GLOBAL_PRNG
|
||||||
#define XGBOOST_CUSTOMIZE_GLOBAL_PRNG XGBOOST_STRICT_R_MODE
|
#define XGBOOST_CUSTOMIZE_GLOBAL_PRNG 0
|
||||||
#endif // XGBOOST_CUSTOMIZE_GLOBAL_PRNG
|
#endif // XGBOOST_CUSTOMIZE_GLOBAL_PRNG
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
|
|||||||
@ -31,7 +31,7 @@ namespace xgboost::common {
|
|||||||
*/
|
*/
|
||||||
using RandomEngine = std::mt19937;
|
using RandomEngine = std::mt19937;
|
||||||
|
|
||||||
#if XGBOOST_CUSTOMIZE_GLOBAL_PRNG
|
#if defined(XGBOOST_CUSTOMIZE_GLOBAL_PRNG) && XGBOOST_CUSTOMIZE_GLOBAL_PRNG == 1
|
||||||
/*!
|
/*!
|
||||||
* \brief A customized random engine, used to plug in a PRNG from other systems.
|
* \brief A customized random engine, used to plug in a PRNG from other systems.
|
||||||
* The implementation of this engine is not provided by the xgboost core library.
|
* The implementation of this engine is not provided by the xgboost core library.
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user