From 91429bd63d00fd090d9b74ddf8cdae9af6a5ebed Mon Sep 17 00:00:00 2001 From: Groves Date: Thu, 3 Dec 2015 06:40:11 -0600 Subject: [PATCH 1/2] Expose model parameters to R --- R-package/R/xgb.train.R | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/R-package/R/xgb.train.R b/R-package/R/xgb.train.R index 768bed27b..d7fa6e1ee 100644 --- a/R-package/R/xgb.train.R +++ b/R-package/R/xgb.train.R @@ -140,6 +140,7 @@ xgb.train <- function(params=list(), data, nrounds, watchlist = list(), warning('watchlist is provided but verbose=0, no evaluation information will be printed') } + fit.call <- match.call() dot.params <- list(...) nms.params <- names(params) nms.dot.params <- names(dot.params) @@ -224,9 +225,13 @@ xgb.train <- function(params=list(), data, nrounds, watchlist = list(), } } bst <- xgb.Booster.check(bst) + if (!is.null(early.stop.round)) { bst$bestScore <- bestScore bst$bestInd <- bestInd } + + attr(bst, "call") <- fit.call + attr(bst, "params") <- params return(bst) } From cd57ea27844c0e40065b7ec516d9c71067947e08 Mon Sep 17 00:00:00 2001 From: Groves Date: Wed, 16 Dec 2015 10:24:16 -0600 Subject: [PATCH 2/2] Add test that model paramaters are accessible within R --- .../tests/testthat/test_parameter_exposure.R | 32 +++++++++++++++++++ 1 file changed, 32 insertions(+) create mode 100644 R-package/tests/testthat/test_parameter_exposure.R diff --git a/R-package/tests/testthat/test_parameter_exposure.R b/R-package/tests/testthat/test_parameter_exposure.R new file mode 100644 index 000000000..769059b76 --- /dev/null +++ b/R-package/tests/testthat/test_parameter_exposure.R @@ -0,0 +1,32 @@ +context('Test model params and call are exposed to R') + +require(xgboost) + +data(agaricus.train, package='xgboost') +data(agaricus.test, package='xgboost') + +dtrain <- xgb.DMatrix(agaricus.train$data, label = agaricus.train$label) +dtest <- xgb.DMatrix(agaricus.test$data, label = agaricus.test$label) + +bst <- xgboost(data = dtrain, + max.depth = 2, + eta = 1, + nround = 10, + nthread = 1, + verbose = 0, + objective = "binary:logistic") + +test_that("call is exposed to R", { + model_call <- attr(bst, "call") + expect_is(model_call, "call") +}) + +test_that("params is exposed to R", { + model_params <- attr(bst, "params") + + expect_is(model_params, "list") + + expect_equal(model_params$eta, 1) + expect_equal(model_params$max.depth, 2) + expect_equal(model_params$objective, "binary:logistic") +})