diff --git a/R-package/R/xgb.load.R b/R-package/R/xgb.load.R index bda4e7e07..d98041908 100644 --- a/R-package/R/xgb.load.R +++ b/R-package/R/xgb.load.R @@ -5,7 +5,7 @@ #' @param modelfile the name of the binary input file. #' #' @details -#' The input file is expected to contain a model saved in an xgboost-internal binary format +#' The input file is expected to contain a model saved in an xgboost model format #' using either \code{\link{xgb.save}} or \code{\link{cb.save.model}} in R, or using some #' appropriate methods from other xgboost interfaces. E.g., a model trained in Python and #' saved from there in xgboost format, could be loaded from R. @@ -38,6 +38,13 @@ xgb.load <- function(modelfile) { handle <- xgb.Booster.handle(modelfile = modelfile) # re-use modelfile if it is raw so we do not need to serialize if (typeof(modelfile) == "raw") { + warning( + paste0( + "The support for loading raw booster with `xgb.load` will be ", + "discontinued in upcoming release. Use `xgb.load.raw` or", + " `xgb.unserialize` instead." + ) + ) bst <- xgb.handleToBooster(handle, modelfile) } else { bst <- xgb.handleToBooster(handle, NULL) diff --git a/R-package/R/xgb.save.raw.R b/R-package/R/xgb.save.raw.R index 967a31482..48fdbca45 100644 --- a/R-package/R/xgb.save.raw.R +++ b/R-package/R/xgb.save.raw.R @@ -4,6 +4,14 @@ #' Save xgboost model from xgboost or xgb.train #' #' @param model the model object. +#' @param raw_format The format for encoding the booster. Available options are +#' \itemize{ +#' \item \code{json}: Encode the booster into a JSON text document. +#' \item \code{ubj}: Encode the booster into Universal Binary JSON. +#' \item \code{deprecated}: Encode the booster into the old customized binary format. +#' } +#' +#' Right now the default is \code{deprecated} but will be changed to \code{ubj} in an upcoming release.
#' #' @examples #' data(agaricus.train, package='xgboost') @@ -17,7 +25,8 @@ #' pred <- predict(bst, test$data) #' #' @export -xgb.save.raw <- function(model) { +xgb.save.raw <- function(model, raw_format = "deprecated") { handle <- xgb.get.handle(model) - .Call(XGBoosterModelToRaw_R, handle) + json_config <- jsonlite::toJSON(list(format = raw_format), auto_unbox = TRUE) + .Call(XGBoosterSaveModelToRaw_R, handle, json_config) } diff --git a/R-package/demo/basic_walkthrough.R b/R-package/demo/basic_walkthrough.R index 6c7f79a03..31f79fb57 100644 --- a/R-package/demo/basic_walkthrough.R +++ b/R-package/demo/basic_walkthrough.R @@ -63,7 +63,7 @@ print(paste("sum(abs(pred2-pred))=", sum(abs(pred2 - pred)))) # save model to R's raw vector raw <- xgb.save.raw(bst) # load binary model to R -bst3 <- xgb.load(raw) +bst3 <- xgb.load.raw(raw) pred3 <- predict(bst3, test$data) # pred3 should be identical to pred print(paste("sum(abs(pred3-pred))=", sum(abs(pred3 - pred)))) diff --git a/R-package/src/init.c b/R-package/src/init.c index 9cfa1ac3f..2af072221 100644 --- a/R-package/src/init.c +++ b/R-package/src/init.c @@ -24,12 +24,12 @@ extern SEXP XGBoosterEvalOneIter_R(SEXP, SEXP, SEXP, SEXP); extern SEXP XGBoosterGetAttrNames_R(SEXP); extern SEXP XGBoosterGetAttr_R(SEXP, SEXP); extern SEXP XGBoosterLoadModelFromRaw_R(SEXP, SEXP); +extern SEXP XGBoosterSaveModelToRaw_R(SEXP handle, SEXP config); extern SEXP XGBoosterLoadModel_R(SEXP, SEXP); extern SEXP XGBoosterSaveJsonConfig_R(SEXP handle); extern SEXP XGBoosterLoadJsonConfig_R(SEXP handle, SEXP value); extern SEXP XGBoosterSerializeToBuffer_R(SEXP handle); extern SEXP XGBoosterUnserializeFromBuffer_R(SEXP handle, SEXP raw); -extern SEXP XGBoosterModelToRaw_R(SEXP); extern SEXP XGBoosterPredict_R(SEXP, SEXP, SEXP, SEXP, SEXP); extern SEXP XGBoosterPredictFromDMatrix_R(SEXP, SEXP, SEXP); extern SEXP XGBoosterSaveModel_R(SEXP, SEXP); @@ -59,12 +59,12 @@ static const R_CallMethodDef CallEntries[] = { {"XGBoosterGetAttrNames_R", (DL_FUNC)
&XGBoosterGetAttrNames_R, 1}, {"XGBoosterGetAttr_R", (DL_FUNC) &XGBoosterGetAttr_R, 2}, {"XGBoosterLoadModelFromRaw_R", (DL_FUNC) &XGBoosterLoadModelFromRaw_R, 2}, + {"XGBoosterSaveModelToRaw_R", (DL_FUNC) &XGBoosterSaveModelToRaw_R, 2}, {"XGBoosterLoadModel_R", (DL_FUNC) &XGBoosterLoadModel_R, 2}, {"XGBoosterSaveJsonConfig_R", (DL_FUNC) &XGBoosterSaveJsonConfig_R, 1}, {"XGBoosterLoadJsonConfig_R", (DL_FUNC) &XGBoosterLoadJsonConfig_R, 2}, {"XGBoosterSerializeToBuffer_R", (DL_FUNC) &XGBoosterSerializeToBuffer_R, 1}, {"XGBoosterUnserializeFromBuffer_R", (DL_FUNC) &XGBoosterUnserializeFromBuffer_R, 2}, - {"XGBoosterModelToRaw_R", (DL_FUNC) &XGBoosterModelToRaw_R, 1}, {"XGBoosterPredict_R", (DL_FUNC) &XGBoosterPredict_R, 5}, {"XGBoosterPredictFromDMatrix_R", (DL_FUNC) &XGBoosterPredictFromDMatrix_R, 3}, {"XGBoosterSaveModel_R", (DL_FUNC) &XGBoosterSaveModel_R, 2}, diff --git a/R-package/src/xgboost_R.cc b/R-package/src/xgboost_R.cc index 9921bb74b..5f7bd6c19 100644 --- a/R-package/src/xgboost_R.cc +++ b/R-package/src/xgboost_R.cc @@ -429,21 +429,6 @@ XGB_DLL SEXP XGBoosterSaveModel_R(SEXP handle, SEXP fname) { return R_NilValue; } -XGB_DLL SEXP XGBoosterModelToRaw_R(SEXP handle) { - SEXP ret; - R_API_BEGIN(); - bst_ulong olen; - const char *raw; - CHECK_CALL(XGBoosterGetModelRaw(R_ExternalPtrAddr(handle), &olen, &raw)); - ret = PROTECT(allocVector(RAWSXP, olen)); - if (olen != 0) { - memcpy(RAW(ret), raw, olen); - } - R_API_END(); - UNPROTECT(1); - return ret; -} - XGB_DLL SEXP XGBoosterLoadModelFromRaw_R(SEXP handle, SEXP raw) { R_API_BEGIN(); CHECK_CALL(XGBoosterLoadModelFromBuffer(R_ExternalPtrAddr(handle), @@ -453,6 +438,22 @@ XGB_DLL SEXP XGBoosterLoadModelFromRaw_R(SEXP handle, SEXP raw) { return R_NilValue; } +XGB_DLL SEXP XGBoosterSaveModelToRaw_R(SEXP handle, SEXP json_config) { + SEXP ret; + R_API_BEGIN(); + bst_ulong olen; + char const *c_json_config = CHAR(asChar(json_config)); + char const *raw; + 
CHECK_CALL(XGBoosterSaveModelToBuffer(R_ExternalPtrAddr(handle), c_json_config, &olen, &raw)); + ret = PROTECT(allocVector(RAWSXP, olen)); + if (olen != 0) { + std::memcpy(RAW(ret), raw, olen); + } + R_API_END(); + UNPROTECT(1); + return ret; +} + XGB_DLL SEXP XGBoosterSaveJsonConfig_R(SEXP handle) { const char* ret; R_API_BEGIN(); diff --git a/R-package/src/xgboost_R.h b/R-package/src/xgboost_R.h index 786514593..7d6edb648 100644 --- a/R-package/src/xgboost_R.h +++ b/R-package/src/xgboost_R.h @@ -209,11 +209,21 @@ XGB_DLL SEXP XGBoosterSaveModel_R(SEXP handle, SEXP fname); XGB_DLL SEXP XGBoosterLoadModelFromRaw_R(SEXP handle, SEXP raw); /*! - * \brief save model into R's raw array + * \brief Save model into R's raw array + * * \param handle handle - * \return raw array + * \param json_config JSON encoded string storing parameters for the function. Following + * keys are expected in the JSON document: + * + * "format": str + * - json: Output booster will be encoded as JSON. + * - ubj: Output booster will be encoded as Universal binary JSON. + * - deprecated: Output booster will be encoded as old custom binary format. Do not use + * this format except for compatibility reasons. + * + * \return Raw array */ -XGB_DLL SEXP XGBoosterModelToRaw_R(SEXP handle); +XGB_DLL SEXP XGBoosterSaveModelToRaw_R(SEXP handle, SEXP json_config); /*! 
* \brief Save internal parameters as a JSON string diff --git a/R-package/tests/testthat/test_io.R b/R-package/tests/testthat/test_io.R new file mode 100644 index 000000000..f4990352f --- /dev/null +++ b/R-package/tests/testthat/test_io.R @@ -0,0 +1,39 @@ +context("Test model IO.") +## Round-trip checks for raw model buffers; some other IO tests are in test_basic.R +library(xgboost) +library(testthat) + +data(agaricus.train, package = "xgboost") +data(agaricus.test, package = "xgboost") +train <- agaricus.train +test <- agaricus.test + +test_that("load/save raw works", { + nrounds <- 8 + booster <- xgboost( + data = train$data, label = train$label, + nrounds = nrounds, objective = "binary:logistic" + ) + + json_bytes <- xgb.save.raw(booster, raw_format = "json") + ubj_bytes <- xgb.save.raw(booster, raw_format = "ubj") + old_bytes <- xgb.save.raw(booster, raw_format = "deprecated") + + from_json <- xgb.load.raw(json_bytes) + from_ubj <- xgb.load.raw(ubj_bytes) + + ## FIXME(jiamingy): Should the 3 wrapping lines below be moved into `xgb.load.raw`? + from_json <- list(handle = from_json, raw = NULL) + class(from_json) <- "xgb.Booster" + from_json <- xgb.Booster.complete(from_json, saveraw = TRUE) + + from_ubj <- list(handle = from_ubj, raw = NULL) + class(from_ubj) <- "xgb.Booster" + from_ubj <- xgb.Booster.complete(from_ubj, saveraw = TRUE) + + json2old <- xgb.save.raw(from_json, raw_format = "deprecated") + ubj2old <- xgb.save.raw(from_ubj, raw_format = "deprecated") + + expect_equal(json2old, ubj2old) + expect_equal(json2old, old_bytes) +})