[R] Implement new save raw in R. (#7571)

This commit is contained in:
Jiaming Yuan 2022-01-22 20:55:47 +08:00 committed by GitHub
parent ef4dae4c0e
commit d262503781
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 90 additions and 24 deletions

View File

@ -5,7 +5,7 @@
#' @param modelfile the name of the binary input file. #' @param modelfile the name of the binary input file.
#' #'
#' @details #' @details
#' The input file is expected to contain a model saved in an xgboost-internal binary format #' The input file is expected to contain a model saved in an xgboost model format
#' using either \code{\link{xgb.save}} or \code{\link{cb.save.model}} in R, or using some #' using either \code{\link{xgb.save}} or \code{\link{cb.save.model}} in R, or using some
#' appropriate methods from other xgboost interfaces. E.g., a model trained in Python and #' appropriate methods from other xgboost interfaces. E.g., a model trained in Python and
#' saved from there in xgboost format, could be loaded from R. #' saved from there in xgboost format, could be loaded from R.
@ -38,6 +38,13 @@ xgb.load <- function(modelfile) {
handle <- xgb.Booster.handle(modelfile = modelfile) handle <- xgb.Booster.handle(modelfile = modelfile)
# re-use modelfile if it is raw so we do not need to serialize # re-use modelfile if it is raw so we do not need to serialize
if (typeof(modelfile) == "raw") { if (typeof(modelfile) == "raw") {
warning(
paste(
"The support for loading raw booster with `xgb.load` will be ",
"discontinued in upcoming release. Use `xgb.load.raw` or",
" `xgb.unserialize` instead. "
)
)
bst <- xgb.handleToBooster(handle, modelfile) bst <- xgb.handleToBooster(handle, modelfile)
} else { } else {
bst <- xgb.handleToBooster(handle, NULL) bst <- xgb.handleToBooster(handle, NULL)

View File

@ -4,6 +4,14 @@
#' Save xgboost model from xgboost or xgb.train #' Save xgboost model from xgboost or xgb.train
#' #'
#' @param model the model object. #' @param model the model object.
#' @param raw_format The format for encoding the booster. Available options are
#' \itemize{
#' \item \code{json}: Encode the booster into JSON text document.
#' \item \code{ubj}: Encode the booster into Universal Binary JSON.
#' \item \code{deprecated}: Encode the booster into old customized binary format.
#' }
#'
#' Right now the default is \code{deprecated} but will be changed to \code{ubj} in upcoming release.
#' #'
#' @examples #' @examples
#' data(agaricus.train, package='xgboost') #' data(agaricus.train, package='xgboost')
@ -17,7 +25,8 @@
#' pred <- predict(bst, test$data) #' pred <- predict(bst, test$data)
#' #'
#' @export #' @export
xgb.save.raw <- function(model) { xgb.save.raw <- function(model, raw_format = "deprecated") {
handle <- xgb.get.handle(model) handle <- xgb.get.handle(model)
.Call(XGBoosterModelToRaw_R, handle) args <- list(format = raw_format)
.Call(XGBoosterSaveModelToRaw_R, handle, jsonlite::toJSON(args, auto_unbox = TRUE))
} }

View File

@ -63,7 +63,7 @@ print(paste("sum(abs(pred2-pred))=", sum(abs(pred2 - pred))))
# save model to R's raw vector # save model to R's raw vector
raw <- xgb.save.raw(bst) raw <- xgb.save.raw(bst)
# load binary model to R # load binary model to R
bst3 <- xgb.load(raw) bst3 <- xgb.load.raw(raw)
pred3 <- predict(bst3, test$data) pred3 <- predict(bst3, test$data)
# pred3 should be identical to pred # pred3 should be identical to pred
print(paste("sum(abs(pred3-pred))=", sum(abs(pred3 - pred)))) print(paste("sum(abs(pred3-pred))=", sum(abs(pred3 - pred))))

View File

@ -24,12 +24,12 @@ extern SEXP XGBoosterEvalOneIter_R(SEXP, SEXP, SEXP, SEXP);
extern SEXP XGBoosterGetAttrNames_R(SEXP); extern SEXP XGBoosterGetAttrNames_R(SEXP);
extern SEXP XGBoosterGetAttr_R(SEXP, SEXP); extern SEXP XGBoosterGetAttr_R(SEXP, SEXP);
extern SEXP XGBoosterLoadModelFromRaw_R(SEXP, SEXP); extern SEXP XGBoosterLoadModelFromRaw_R(SEXP, SEXP);
extern SEXP XGBoosterSaveModelToRaw_R(SEXP handle, SEXP config);
extern SEXP XGBoosterLoadModel_R(SEXP, SEXP); extern SEXP XGBoosterLoadModel_R(SEXP, SEXP);
extern SEXP XGBoosterSaveJsonConfig_R(SEXP handle); extern SEXP XGBoosterSaveJsonConfig_R(SEXP handle);
extern SEXP XGBoosterLoadJsonConfig_R(SEXP handle, SEXP value); extern SEXP XGBoosterLoadJsonConfig_R(SEXP handle, SEXP value);
extern SEXP XGBoosterSerializeToBuffer_R(SEXP handle); extern SEXP XGBoosterSerializeToBuffer_R(SEXP handle);
extern SEXP XGBoosterUnserializeFromBuffer_R(SEXP handle, SEXP raw); extern SEXP XGBoosterUnserializeFromBuffer_R(SEXP handle, SEXP raw);
extern SEXP XGBoosterModelToRaw_R(SEXP);
extern SEXP XGBoosterPredict_R(SEXP, SEXP, SEXP, SEXP, SEXP); extern SEXP XGBoosterPredict_R(SEXP, SEXP, SEXP, SEXP, SEXP);
extern SEXP XGBoosterPredictFromDMatrix_R(SEXP, SEXP, SEXP); extern SEXP XGBoosterPredictFromDMatrix_R(SEXP, SEXP, SEXP);
extern SEXP XGBoosterSaveModel_R(SEXP, SEXP); extern SEXP XGBoosterSaveModel_R(SEXP, SEXP);
@ -59,12 +59,12 @@ static const R_CallMethodDef CallEntries[] = {
{"XGBoosterGetAttrNames_R", (DL_FUNC) &XGBoosterGetAttrNames_R, 1}, {"XGBoosterGetAttrNames_R", (DL_FUNC) &XGBoosterGetAttrNames_R, 1},
{"XGBoosterGetAttr_R", (DL_FUNC) &XGBoosterGetAttr_R, 2}, {"XGBoosterGetAttr_R", (DL_FUNC) &XGBoosterGetAttr_R, 2},
{"XGBoosterLoadModelFromRaw_R", (DL_FUNC) &XGBoosterLoadModelFromRaw_R, 2}, {"XGBoosterLoadModelFromRaw_R", (DL_FUNC) &XGBoosterLoadModelFromRaw_R, 2},
{"XGBoosterSaveModelToRaw_R", (DL_FUNC) &XGBoosterSaveModelToRaw_R, 2},
{"XGBoosterLoadModel_R", (DL_FUNC) &XGBoosterLoadModel_R, 2}, {"XGBoosterLoadModel_R", (DL_FUNC) &XGBoosterLoadModel_R, 2},
{"XGBoosterSaveJsonConfig_R", (DL_FUNC) &XGBoosterSaveJsonConfig_R, 1}, {"XGBoosterSaveJsonConfig_R", (DL_FUNC) &XGBoosterSaveJsonConfig_R, 1},
{"XGBoosterLoadJsonConfig_R", (DL_FUNC) &XGBoosterLoadJsonConfig_R, 2}, {"XGBoosterLoadJsonConfig_R", (DL_FUNC) &XGBoosterLoadJsonConfig_R, 2},
{"XGBoosterSerializeToBuffer_R", (DL_FUNC) &XGBoosterSerializeToBuffer_R, 1}, {"XGBoosterSerializeToBuffer_R", (DL_FUNC) &XGBoosterSerializeToBuffer_R, 1},
{"XGBoosterUnserializeFromBuffer_R", (DL_FUNC) &XGBoosterUnserializeFromBuffer_R, 2}, {"XGBoosterUnserializeFromBuffer_R", (DL_FUNC) &XGBoosterUnserializeFromBuffer_R, 2},
{"XGBoosterModelToRaw_R", (DL_FUNC) &XGBoosterModelToRaw_R, 1},
{"XGBoosterPredict_R", (DL_FUNC) &XGBoosterPredict_R, 5}, {"XGBoosterPredict_R", (DL_FUNC) &XGBoosterPredict_R, 5},
{"XGBoosterPredictFromDMatrix_R", (DL_FUNC) &XGBoosterPredictFromDMatrix_R, 3}, {"XGBoosterPredictFromDMatrix_R", (DL_FUNC) &XGBoosterPredictFromDMatrix_R, 3},
{"XGBoosterSaveModel_R", (DL_FUNC) &XGBoosterSaveModel_R, 2}, {"XGBoosterSaveModel_R", (DL_FUNC) &XGBoosterSaveModel_R, 2},

View File

@ -429,21 +429,6 @@ XGB_DLL SEXP XGBoosterSaveModel_R(SEXP handle, SEXP fname) {
return R_NilValue; return R_NilValue;
} }
XGB_DLL SEXP XGBoosterModelToRaw_R(SEXP handle) {
SEXP ret;
R_API_BEGIN();
bst_ulong olen;
const char *raw;
CHECK_CALL(XGBoosterGetModelRaw(R_ExternalPtrAddr(handle), &olen, &raw));
ret = PROTECT(allocVector(RAWSXP, olen));
if (olen != 0) {
memcpy(RAW(ret), raw, olen);
}
R_API_END();
UNPROTECT(1);
return ret;
}
XGB_DLL SEXP XGBoosterLoadModelFromRaw_R(SEXP handle, SEXP raw) { XGB_DLL SEXP XGBoosterLoadModelFromRaw_R(SEXP handle, SEXP raw) {
R_API_BEGIN(); R_API_BEGIN();
CHECK_CALL(XGBoosterLoadModelFromBuffer(R_ExternalPtrAddr(handle), CHECK_CALL(XGBoosterLoadModelFromBuffer(R_ExternalPtrAddr(handle),
@ -453,6 +438,22 @@ XGB_DLL SEXP XGBoosterLoadModelFromRaw_R(SEXP handle, SEXP raw) {
return R_NilValue; return R_NilValue;
} }
XGB_DLL SEXP XGBoosterSaveModelToRaw_R(SEXP handle, SEXP json_config) {
SEXP ret;
R_API_BEGIN();
bst_ulong olen;
char const *c_json_config = CHAR(asChar(json_config));
char const *raw;
CHECK_CALL(XGBoosterSaveModelToBuffer(R_ExternalPtrAddr(handle), c_json_config, &olen, &raw))
ret = PROTECT(allocVector(RAWSXP, olen));
if (olen != 0) {
std::memcpy(RAW(ret), raw, olen);
}
R_API_END();
UNPROTECT(1);
return ret;
}
XGB_DLL SEXP XGBoosterSaveJsonConfig_R(SEXP handle) { XGB_DLL SEXP XGBoosterSaveJsonConfig_R(SEXP handle) {
const char* ret; const char* ret;
R_API_BEGIN(); R_API_BEGIN();

View File

@ -209,11 +209,21 @@ XGB_DLL SEXP XGBoosterSaveModel_R(SEXP handle, SEXP fname);
XGB_DLL SEXP XGBoosterLoadModelFromRaw_R(SEXP handle, SEXP raw); XGB_DLL SEXP XGBoosterLoadModelFromRaw_R(SEXP handle, SEXP raw);
/*! /*!
* \brief save model into R's raw array * \brief Save model into R's raw array
*
* \param handle handle * \param handle handle
* \return raw array * \param json_config JSON encoded string storing parameters for the function. Following
* keys are expected in the JSON document:
*
* "format": str
* - json: Output booster will be encoded as JSON.
* - ubj: Output booster will be encoded as Univeral binary JSON.
* - deprecated: Output booster will be encoded as old custom binary format. Do now use
* this format except for compatibility reasons.
*
* \return Raw array
*/ */
XGB_DLL SEXP XGBoosterModelToRaw_R(SEXP handle); XGB_DLL SEXP XGBoosterSaveModelToRaw_R(SEXP handle, SEXP json_config);
/*! /*!
* \brief Save internal parameters as a JSON string * \brief Save internal parameters as a JSON string

View File

@ -0,0 +1,39 @@
context("Test model IO.")
## some other tests are in test_basic.R
require(xgboost)
require(testthat)
data(agaricus.train, package = "xgboost")
data(agaricus.test, package = "xgboost")
train <- agaricus.train
test <- agaricus.test
test_that("load/save raw works", {
nrounds <- 8
booster <- xgboost(
data = train$data, label = train$label,
nrounds = nrounds, objective = "binary:logistic"
)
json_bytes <- xgb.save.raw(booster, raw_format = "json")
ubj_bytes <- xgb.save.raw(booster, raw_format = "ubj")
old_bytes <- xgb.save.raw(booster, raw_format = "deprecated")
from_json <- xgb.load.raw(json_bytes)
from_ubj <- xgb.load.raw(ubj_bytes)
## FIXME(jiamingy): Should we include these 3 lines into `xgb.load.raw`?
from_json <- list(handle = from_json, raw = NULL)
class(from_json) <- "xgb.Booster"
from_json <- xgb.Booster.complete(from_json, saveraw = TRUE)
from_ubj <- list(handle = from_ubj, raw = NULL)
class(from_ubj) <- "xgb.Booster"
from_ubj <- xgb.Booster.complete(from_ubj, saveraw = TRUE)
json2old <- xgb.save.raw(from_json, raw_format = "deprecated")
ubj2old <- xgb.save.raw(from_ubj, raw_format = "deprecated")
expect_equal(json2old, ubj2old)
expect_equal(json2old, old_bytes)
})