diff --git a/R-package/R/xgb.train.R b/R-package/R/xgb.train.R index 768bed27b..d7fa6e1ee 100644 --- a/R-package/R/xgb.train.R +++ b/R-package/R/xgb.train.R @@ -140,6 +140,7 @@ xgb.train <- function(params=list(), data, nrounds, watchlist = list(), warning('watchlist is provided but verbose=0, no evaluation information will be printed') } + fit.call <- match.call() dot.params <- list(...) nms.params <- names(params) nms.dot.params <- names(dot.params) @@ -224,9 +225,13 @@ xgb.train <- function(params=list(), data, nrounds, watchlist = list(), } } bst <- xgb.Booster.check(bst) + if (!is.null(early.stop.round)) { bst$bestScore <- bestScore bst$bestInd <- bestInd } + + attr(bst, "call") <- fit.call + attr(bst, "params") <- params return(bst) } diff --git a/R-package/tests/testthat/test_parameter_exposure.R b/R-package/tests/testthat/test_parameter_exposure.R new file mode 100644 index 000000000..769059b76 --- /dev/null +++ b/R-package/tests/testthat/test_parameter_exposure.R @@ -0,0 +1,32 @@ +context('Test model params and call are exposed to R') + +require(xgboost) + +data(agaricus.train, package='xgboost') +data(agaricus.test, package='xgboost') + +dtrain <- xgb.DMatrix(agaricus.train$data, label = agaricus.train$label) +dtest <- xgb.DMatrix(agaricus.test$data, label = agaricus.test$label) + +bst <- xgboost(data = dtrain, + max.depth = 2, + eta = 1, + nround = 10, + nthread = 1, + verbose = 0, + objective = "binary:logistic") + +test_that("call is exposed to R", { + model_call <- attr(bst, "call") + expect_is(model_call, "call") +}) + +test_that("params is exposed to R", { + model_params <- attr(bst, "params") + + expect_is(model_params, "list") + + expect_equal(model_params$eta, 1) + expect_equal(model_params$max.depth, 2) + expect_equal(model_params$objective, "binary:logistic") +})