[R] allow using seed with regular RNG (#10029)
This commit is contained in:
@@ -778,3 +778,66 @@ test_that("DMatrix field are set to booster when training", {
|
||||
expect_equal(getinfo(model_feature_types, "feature_type"), c("q", "c", "q"))
|
||||
expect_equal(getinfo(model_both, "feature_type"), c("q", "c", "q"))
|
||||
})
|
||||
|
||||
test_that("Seed in params override PRNG from R", {
|
||||
set.seed(123)
|
||||
model1 <- xgb.train(
|
||||
data = xgb.DMatrix(
|
||||
agaricus.train$data,
|
||||
label = agaricus.train$label, nthread = 1L
|
||||
),
|
||||
params = list(
|
||||
objective = "binary:logistic",
|
||||
max_depth = 3L,
|
||||
subsample = 0.1,
|
||||
colsample_bytree = 0.1,
|
||||
seed = 111L
|
||||
),
|
||||
nrounds = 3L
|
||||
)
|
||||
|
||||
set.seed(456)
|
||||
model2 <- xgb.train(
|
||||
data = xgb.DMatrix(
|
||||
agaricus.train$data,
|
||||
label = agaricus.train$label, nthread = 1L
|
||||
),
|
||||
params = list(
|
||||
objective = "binary:logistic",
|
||||
max_depth = 3L,
|
||||
subsample = 0.1,
|
||||
colsample_bytree = 0.1,
|
||||
seed = 111L
|
||||
),
|
||||
nrounds = 3L
|
||||
)
|
||||
|
||||
expect_equal(
|
||||
xgb.save.raw(model1, raw_format = "json"),
|
||||
xgb.save.raw(model2, raw_format = "json")
|
||||
)
|
||||
|
||||
set.seed(123)
|
||||
model3 <- xgb.train(
|
||||
data = xgb.DMatrix(
|
||||
agaricus.train$data,
|
||||
label = agaricus.train$label, nthread = 1L
|
||||
),
|
||||
params = list(
|
||||
objective = "binary:logistic",
|
||||
max_depth = 3L,
|
||||
subsample = 0.1,
|
||||
colsample_bytree = 0.1,
|
||||
seed = 222L
|
||||
),
|
||||
nrounds = 3L
|
||||
)
|
||||
expect_false(
|
||||
isTRUE(
|
||||
all.equal(
|
||||
xgb.save.raw(model1, raw_format = "json"),
|
||||
xgb.save.raw(model3, raw_format = "json")
|
||||
)
|
||||
)
|
||||
)
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user