diff --git a/R-package/NAMESPACE b/R-package/NAMESPACE index 11a812c9e..cd2a3b92b 100644 --- a/R-package/NAMESPACE +++ b/R-package/NAMESPACE @@ -14,6 +14,7 @@ S3method(setinfo,xgb.DMatrix) S3method(slice,xgb.DMatrix) export("xgb.attr<-") export("xgb.attributes<-") +export("xgb.config<-") export("xgb.parameters<-") export(cb.cv.predict) export(cb.early.stop) @@ -30,6 +31,7 @@ export(xgb.DMatrix) export(xgb.DMatrix.save) export(xgb.attr) export(xgb.attributes) +export(xgb.config) export(xgb.create.features) export(xgb.cv) export(xgb.dump) diff --git a/R-package/R/xgb.Booster.R b/R-package/R/xgb.Booster.R index 660264e0b..dd901b07d 100644 --- a/R-package/R/xgb.Booster.R +++ b/R-package/R/xgb.Booster.R @@ -503,6 +503,35 @@ xgb.attributes <- function(object) { object } +#' Accessors for model parameters as JSON string. +#' +#' @param object Object of class \code{xgb.Booster} +#' @param value A JSON string. +#' +#' @examples +#' data(agaricus.train, package='xgboost') +#' train <- agaricus.train +#' +#' bst <- xgboost(data = train$data, label = train$label, max_depth = 2, +#' eta = 1, nthread = 2, nrounds = 2, objective = "binary:logistic") +#' config <- xgb.config(bst) +#' +#' @rdname xgb.config +#' @export +xgb.config <- function(object) { + handle <- xgb.get.handle(object) + .Call(XGBoosterSaveJsonConfig_R, handle); +} + +#' @rdname xgb.config +#' @export +`xgb.config<-` <- function(object, value) { + handle <- xgb.get.handle(object) + .Call(XGBoosterLoadJsonConfig_R, handle, value) + object$raw <- xgb.Booster.complete(object) + object +} + #' Accessors for model parameters. #' #' Only the setter for xgboost parameters is currently implemented. diff --git a/R-package/man/xgb.config.Rd b/R-package/man/xgb.config.Rd new file mode 100644 index 000000000..a5187c8ea --- /dev/null +++ b/R-package/man/xgb.config.Rd @@ -0,0 +1,28 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/xgb.Booster.R +\name{xgb.config} +\alias{xgb.config} +\alias{xgb.config<-} +\title{Accessors for model parameters as JSON string.} +\usage{ +xgb.config(object) + +xgb.config(object) <- value +} +\arguments{ +\item{object}{Object of class \code{xgb.Booster}} + +\item{value}{A JSON string.} +} +\description{ +Accessors for model parameters as JSON string. +} +\examples{ +data(agaricus.train, package='xgboost') +train <- agaricus.train + +bst <- xgboost(data = train$data, label = train$label, max_depth = 2, + eta = 1, nthread = 2, nrounds = 2, objective = "binary:logistic") +config <- xgb.config(bst) + +} diff --git a/R-package/src/init.c b/R-package/src/init.c index 82b853217..b85d4e756 100644 --- a/R-package/src/init.c +++ b/R-package/src/init.c @@ -23,6 +23,8 @@ extern SEXP XGBoosterGetAttrNames_R(SEXP); extern SEXP XGBoosterGetAttr_R(SEXP, SEXP); extern SEXP XGBoosterLoadModelFromRaw_R(SEXP, SEXP); extern SEXP XGBoosterLoadModel_R(SEXP, SEXP); +extern SEXP XGBoosterSaveJsonConfig_R(SEXP handle); +extern SEXP XGBoosterLoadJsonConfig_R(SEXP handle, SEXP value); extern SEXP XGBoosterModelToRaw_R(SEXP); extern SEXP XGBoosterPredict_R(SEXP, SEXP, SEXP, SEXP, SEXP); extern SEXP XGBoosterSaveModel_R(SEXP, SEXP); @@ -49,6 +51,8 @@ static const R_CallMethodDef CallEntries[] = { {"XGBoosterGetAttr_R", (DL_FUNC) &XGBoosterGetAttr_R, 2}, {"XGBoosterLoadModelFromRaw_R", (DL_FUNC) &XGBoosterLoadModelFromRaw_R, 2}, {"XGBoosterLoadModel_R", (DL_FUNC) &XGBoosterLoadModel_R, 2}, + {"XGBoosterSaveJsonConfig_R", (DL_FUNC) &XGBoosterSaveJsonConfig_R, 1}, + {"XGBoosterLoadJsonConfig_R", (DL_FUNC) &XGBoosterLoadJsonConfig_R, 2}, {"XGBoosterModelToRaw_R", (DL_FUNC) &XGBoosterModelToRaw_R, 1}, {"XGBoosterPredict_R", (DL_FUNC) &XGBoosterPredict_R, 5}, {"XGBoosterSaveModel_R", (DL_FUNC) &XGBoosterSaveModel_R, 2}, diff --git a/R-package/src/xgboost_R.cc b/R-package/src/xgboost_R.cc index c9083177d..d89ea7bed 100644 --- a/R-package/src/xgboost_R.cc +++ b/R-package/src/xgboost_R.cc @@ -362,6 +362,24 @@ SEXP XGBoosterModelToRaw_R(SEXP handle) { return ret; } +SEXP XGBoosterSaveJsonConfig_R(SEXP handle) { + const char* ret; + R_API_BEGIN(); + bst_ulong len {0}; + CHECK_CALL(XGBoosterSaveJsonConfig(R_ExternalPtrAddr(handle), + &len, + &ret)); + R_API_END(); + return mkString(ret); +} + +SEXP XGBoosterLoadJsonConfig_R(SEXP handle, SEXP value) { + R_API_BEGIN(); + XGBoosterLoadJsonConfig(R_ExternalPtrAddr(handle), CHAR(asChar(value))); + R_API_END(); + return R_NilValue; +} + SEXP XGBoosterDumpModel_R(SEXP handle, SEXP fmap, SEXP with_stats, SEXP dump_format) { SEXP out; R_API_BEGIN(); diff --git a/R-package/src/xgboost_R.h b/R-package/src/xgboost_R.h index 764050fd8..05cd7afbe 100644 --- a/R-package/src/xgboost_R.h +++ b/R-package/src/xgboost_R.h @@ -179,9 +179,22 @@ XGB_DLL SEXP XGBoosterLoadModelFromRaw_R(SEXP handle, SEXP raw); * \brief save model into R's raw array * \param handle handle * \return raw array - */ + */ XGB_DLL SEXP XGBoosterModelToRaw_R(SEXP handle); +/*! + * \brief Save internal parameters as a JSON string + * \param handle handle + * \return JSON string + */ +XGB_DLL SEXP XGBoosterSaveJsonConfig_R(SEXP handle); +/*! + * \brief Load the JSON string returnd by XGBoosterSaveJsonConfig_R + * \param handle handle + * \param value JSON string + * \return R_NilValue + */ +XGB_DLL SEXP XGBoosterLoadJsonConfig_R(SEXP handle, SEXP value); /*! * \brief dump model into a string * \param handle handle diff --git a/R-package/tests/testthat/test_basic.R b/R-package/tests/testthat/test_basic.R index 97b90f7a1..eb881ed46 100644 --- a/R-package/tests/testthat/test_basic.R +++ b/R-package/tests/testthat/test_basic.R @@ -324,3 +324,13 @@ test_that("colsample_bytree works", { # in the 100 trees expect_gte(nrow(xgb.importance(model = bst)), 30) }) + +test_that("Configuration works", { + bst <- xgboost(data = train$data, label = train$label, max_depth = 2, + eta = 1, nthread = 2, nrounds = 2, objective = "binary:logistic", + eval_metric = 'error', eval_metric = 'auc', eval_metric = "logloss") + config <- xgb.config(bst) + xgb.config(bst) <- config + reloaded_config <- xgb.config(bst) + expect_equal(config, reloaded_config); +}) diff --git a/doc/tutorials/saving_model.rst b/doc/tutorials/saving_model.rst index 7d416ccb1..88a097ac1 100644 --- a/doc/tutorials/saving_model.rst +++ b/doc/tutorials/saving_model.rst @@ -102,7 +102,7 @@ comments in the script for more details. Saving and Loading the internal parameters configuration ******************************************************** -XGBoost's ``C API`` and ``Python API`` supports saving and loading the internal +XGBoost's ``C API``, ``Python API`` and ``R API`` support saving and loading the internal configuration directly as a JSON string. In Python package: .. code-block:: python @@ -111,6 +111,14 @@ configuration directly as a JSON string. In Python package: config = bst.save_config() print(config) + +or + +.. code-block:: R + + config <- xgb.config(bst) + print(config) + Will print out something similiar to (not actual output as it's too long for demonstration): .. code-block:: json