[R] Implement new save raw in R. (#7571)
This commit is contained in:
parent
ef4dae4c0e
commit
d262503781
@ -5,7 +5,7 @@
|
||||
#' @param modelfile the name of the binary input file.
|
||||
#'
|
||||
#' @details
|
||||
#' The input file is expected to contain a model saved in an xgboost-internal binary format
|
||||
#' The input file is expected to contain a model saved in an xgboost model format
|
||||
#' using either \code{\link{xgb.save}} or \code{\link{cb.save.model}} in R, or using some
|
||||
#' appropriate methods from other xgboost interfaces. E.g., a model trained in Python and
|
||||
#' saved from there in xgboost format, could be loaded from R.
|
||||
@ -38,6 +38,13 @@ xgb.load <- function(modelfile) {
|
||||
handle <- xgb.Booster.handle(modelfile = modelfile)
|
||||
# re-use modelfile if it is raw so we do not need to serialize
|
||||
if (typeof(modelfile) == "raw") {
|
||||
warning(
|
||||
paste(
|
||||
"The support for loading raw booster with `xgb.load` will be ",
|
||||
"discontinued in upcoming release. Use `xgb.load.raw` or",
|
||||
" `xgb.unserialize` instead. "
|
||||
)
|
||||
)
|
||||
bst <- xgb.handleToBooster(handle, modelfile)
|
||||
} else {
|
||||
bst <- xgb.handleToBooster(handle, NULL)
|
||||
|
||||
@ -4,6 +4,14 @@
|
||||
#' Save xgboost model from xgboost or xgb.train
|
||||
#'
|
||||
#' @param model the model object.
|
||||
#' @param raw_format The format for encoding the booster. Available options are
|
||||
#' \itemize{
|
||||
#' \item \code{json}: Encode the booster into JSON text document.
|
||||
#' \item \code{ubj}: Encode the booster into Universal Binary JSON.
|
||||
#' \item \code{deprecated}: Encode the booster into old customized binary format.
|
||||
#' }
|
||||
#'
|
||||
#' Right now the default is \code{deprecated} but will be changed to \code{ubj} in upcoming release.
|
||||
#'
|
||||
#' @examples
|
||||
#' data(agaricus.train, package='xgboost')
|
||||
@ -17,7 +25,8 @@
|
||||
#' pred <- predict(bst, test$data)
|
||||
#'
|
||||
#' @export
|
||||
xgb.save.raw <- function(model) {
|
||||
xgb.save.raw <- function(model, raw_format = "deprecated") {
|
||||
handle <- xgb.get.handle(model)
|
||||
.Call(XGBoosterModelToRaw_R, handle)
|
||||
args <- list(format = raw_format)
|
||||
.Call(XGBoosterSaveModelToRaw_R, handle, jsonlite::toJSON(args, auto_unbox = TRUE))
|
||||
}
|
||||
|
||||
@ -63,7 +63,7 @@ print(paste("sum(abs(pred2-pred))=", sum(abs(pred2 - pred))))
|
||||
# save model to R's raw vector
|
||||
raw <- xgb.save.raw(bst)
|
||||
# load binary model to R
|
||||
bst3 <- xgb.load(raw)
|
||||
bst3 <- xgb.load.raw(raw)
|
||||
pred3 <- predict(bst3, test$data)
|
||||
# pred3 should be identical to pred
|
||||
print(paste("sum(abs(pred3-pred))=", sum(abs(pred3 - pred))))
|
||||
|
||||
@ -24,12 +24,12 @@ extern SEXP XGBoosterEvalOneIter_R(SEXP, SEXP, SEXP, SEXP);
|
||||
extern SEXP XGBoosterGetAttrNames_R(SEXP);
|
||||
extern SEXP XGBoosterGetAttr_R(SEXP, SEXP);
|
||||
extern SEXP XGBoosterLoadModelFromRaw_R(SEXP, SEXP);
|
||||
extern SEXP XGBoosterSaveModelToRaw_R(SEXP handle, SEXP config);
|
||||
extern SEXP XGBoosterLoadModel_R(SEXP, SEXP);
|
||||
extern SEXP XGBoosterSaveJsonConfig_R(SEXP handle);
|
||||
extern SEXP XGBoosterLoadJsonConfig_R(SEXP handle, SEXP value);
|
||||
extern SEXP XGBoosterSerializeToBuffer_R(SEXP handle);
|
||||
extern SEXP XGBoosterUnserializeFromBuffer_R(SEXP handle, SEXP raw);
|
||||
extern SEXP XGBoosterModelToRaw_R(SEXP);
|
||||
extern SEXP XGBoosterPredict_R(SEXP, SEXP, SEXP, SEXP, SEXP);
|
||||
extern SEXP XGBoosterPredictFromDMatrix_R(SEXP, SEXP, SEXP);
|
||||
extern SEXP XGBoosterSaveModel_R(SEXP, SEXP);
|
||||
@ -59,12 +59,12 @@ static const R_CallMethodDef CallEntries[] = {
|
||||
{"XGBoosterGetAttrNames_R", (DL_FUNC) &XGBoosterGetAttrNames_R, 1},
|
||||
{"XGBoosterGetAttr_R", (DL_FUNC) &XGBoosterGetAttr_R, 2},
|
||||
{"XGBoosterLoadModelFromRaw_R", (DL_FUNC) &XGBoosterLoadModelFromRaw_R, 2},
|
||||
{"XGBoosterSaveModelToRaw_R", (DL_FUNC) &XGBoosterSaveModelToRaw_R, 2},
|
||||
{"XGBoosterLoadModel_R", (DL_FUNC) &XGBoosterLoadModel_R, 2},
|
||||
{"XGBoosterSaveJsonConfig_R", (DL_FUNC) &XGBoosterSaveJsonConfig_R, 1},
|
||||
{"XGBoosterLoadJsonConfig_R", (DL_FUNC) &XGBoosterLoadJsonConfig_R, 2},
|
||||
{"XGBoosterSerializeToBuffer_R", (DL_FUNC) &XGBoosterSerializeToBuffer_R, 1},
|
||||
{"XGBoosterUnserializeFromBuffer_R", (DL_FUNC) &XGBoosterUnserializeFromBuffer_R, 2},
|
||||
{"XGBoosterModelToRaw_R", (DL_FUNC) &XGBoosterModelToRaw_R, 1},
|
||||
{"XGBoosterPredict_R", (DL_FUNC) &XGBoosterPredict_R, 5},
|
||||
{"XGBoosterPredictFromDMatrix_R", (DL_FUNC) &XGBoosterPredictFromDMatrix_R, 3},
|
||||
{"XGBoosterSaveModel_R", (DL_FUNC) &XGBoosterSaveModel_R, 2},
|
||||
|
||||
@ -429,21 +429,6 @@ XGB_DLL SEXP XGBoosterSaveModel_R(SEXP handle, SEXP fname) {
|
||||
return R_NilValue;
|
||||
}
|
||||
|
||||
XGB_DLL SEXP XGBoosterModelToRaw_R(SEXP handle) {
|
||||
SEXP ret;
|
||||
R_API_BEGIN();
|
||||
bst_ulong olen;
|
||||
const char *raw;
|
||||
CHECK_CALL(XGBoosterGetModelRaw(R_ExternalPtrAddr(handle), &olen, &raw));
|
||||
ret = PROTECT(allocVector(RAWSXP, olen));
|
||||
if (olen != 0) {
|
||||
memcpy(RAW(ret), raw, olen);
|
||||
}
|
||||
R_API_END();
|
||||
UNPROTECT(1);
|
||||
return ret;
|
||||
}
|
||||
|
||||
XGB_DLL SEXP XGBoosterLoadModelFromRaw_R(SEXP handle, SEXP raw) {
|
||||
R_API_BEGIN();
|
||||
CHECK_CALL(XGBoosterLoadModelFromBuffer(R_ExternalPtrAddr(handle),
|
||||
@ -453,6 +438,22 @@ XGB_DLL SEXP XGBoosterLoadModelFromRaw_R(SEXP handle, SEXP raw) {
|
||||
return R_NilValue;
|
||||
}
|
||||
|
||||
XGB_DLL SEXP XGBoosterSaveModelToRaw_R(SEXP handle, SEXP json_config) {
|
||||
SEXP ret;
|
||||
R_API_BEGIN();
|
||||
bst_ulong olen;
|
||||
char const *c_json_config = CHAR(asChar(json_config));
|
||||
char const *raw;
|
||||
CHECK_CALL(XGBoosterSaveModelToBuffer(R_ExternalPtrAddr(handle), c_json_config, &olen, &raw))
|
||||
ret = PROTECT(allocVector(RAWSXP, olen));
|
||||
if (olen != 0) {
|
||||
std::memcpy(RAW(ret), raw, olen);
|
||||
}
|
||||
R_API_END();
|
||||
UNPROTECT(1);
|
||||
return ret;
|
||||
}
|
||||
|
||||
XGB_DLL SEXP XGBoosterSaveJsonConfig_R(SEXP handle) {
|
||||
const char* ret;
|
||||
R_API_BEGIN();
|
||||
|
||||
@ -209,11 +209,21 @@ XGB_DLL SEXP XGBoosterSaveModel_R(SEXP handle, SEXP fname);
|
||||
XGB_DLL SEXP XGBoosterLoadModelFromRaw_R(SEXP handle, SEXP raw);
|
||||
|
||||
/*!
|
||||
* \brief save model into R's raw array
|
||||
* \brief Save model into R's raw array
|
||||
*
|
||||
* \param handle handle
|
||||
* \return raw array
|
||||
* \param json_config JSON encoded string storing parameters for the function. Following
|
||||
* keys are expected in the JSON document:
|
||||
*
|
||||
* "format": str
|
||||
* - json: Output booster will be encoded as JSON.
|
||||
* - ubj: Output booster will be encoded as Univeral binary JSON.
|
||||
* - deprecated: Output booster will be encoded as old custom binary format. Do now use
|
||||
* this format except for compatibility reasons.
|
||||
*
|
||||
* \return Raw array
|
||||
*/
|
||||
XGB_DLL SEXP XGBoosterModelToRaw_R(SEXP handle);
|
||||
XGB_DLL SEXP XGBoosterSaveModelToRaw_R(SEXP handle, SEXP json_config);
|
||||
|
||||
/*!
|
||||
* \brief Save internal parameters as a JSON string
|
||||
|
||||
39
R-package/tests/testthat/test_io.R
Normal file
39
R-package/tests/testthat/test_io.R
Normal file
@ -0,0 +1,39 @@
|
||||
context("Test model IO.")
|
||||
## some other tests are in test_basic.R
|
||||
require(xgboost)
|
||||
require(testthat)
|
||||
|
||||
data(agaricus.train, package = "xgboost")
|
||||
data(agaricus.test, package = "xgboost")
|
||||
train <- agaricus.train
|
||||
test <- agaricus.test
|
||||
|
||||
test_that("load/save raw works", {
|
||||
nrounds <- 8
|
||||
booster <- xgboost(
|
||||
data = train$data, label = train$label,
|
||||
nrounds = nrounds, objective = "binary:logistic"
|
||||
)
|
||||
|
||||
json_bytes <- xgb.save.raw(booster, raw_format = "json")
|
||||
ubj_bytes <- xgb.save.raw(booster, raw_format = "ubj")
|
||||
old_bytes <- xgb.save.raw(booster, raw_format = "deprecated")
|
||||
|
||||
from_json <- xgb.load.raw(json_bytes)
|
||||
from_ubj <- xgb.load.raw(ubj_bytes)
|
||||
|
||||
## FIXME(jiamingy): Should we include these 3 lines into `xgb.load.raw`?
|
||||
from_json <- list(handle = from_json, raw = NULL)
|
||||
class(from_json) <- "xgb.Booster"
|
||||
from_json <- xgb.Booster.complete(from_json, saveraw = TRUE)
|
||||
|
||||
from_ubj <- list(handle = from_ubj, raw = NULL)
|
||||
class(from_ubj) <- "xgb.Booster"
|
||||
from_ubj <- xgb.Booster.complete(from_ubj, saveraw = TRUE)
|
||||
|
||||
json2old <- xgb.save.raw(from_json, raw_format = "deprecated")
|
||||
ubj2old <- xgb.save.raw(from_ubj, raw_format = "deprecated")
|
||||
|
||||
expect_equal(json2old, ubj2old)
|
||||
expect_equal(json2old, old_bytes)
|
||||
})
|
||||
Loading…
x
Reference in New Issue
Block a user