[R] Implement new save raw in R. (#7571)
This commit is contained in:
parent
ef4dae4c0e
commit
d262503781
@ -5,7 +5,7 @@
|
|||||||
#' @param modelfile the name of the binary input file.
|
#' @param modelfile the name of the binary input file.
|
||||||
#'
|
#'
|
||||||
#' @details
|
#' @details
|
||||||
#' The input file is expected to contain a model saved in an xgboost-internal binary format
|
#' The input file is expected to contain a model saved in an xgboost model format
|
||||||
#' using either \code{\link{xgb.save}} or \code{\link{cb.save.model}} in R, or using some
|
#' using either \code{\link{xgb.save}} or \code{\link{cb.save.model}} in R, or using some
|
||||||
#' appropriate methods from other xgboost interfaces. E.g., a model trained in Python and
|
#' appropriate methods from other xgboost interfaces. E.g., a model trained in Python and
|
||||||
#' saved from there in xgboost format, could be loaded from R.
|
#' saved from there in xgboost format, could be loaded from R.
|
||||||
@ -38,6 +38,13 @@ xgb.load <- function(modelfile) {
|
|||||||
handle <- xgb.Booster.handle(modelfile = modelfile)
|
handle <- xgb.Booster.handle(modelfile = modelfile)
|
||||||
# re-use modelfile if it is raw so we do not need to serialize
|
# re-use modelfile if it is raw so we do not need to serialize
|
||||||
if (typeof(modelfile) == "raw") {
|
if (typeof(modelfile) == "raw") {
|
||||||
|
warning(
|
||||||
|
paste(
|
||||||
|
"The support for loading raw booster with `xgb.load` will be ",
|
||||||
|
"discontinued in upcoming release. Use `xgb.load.raw` or",
|
||||||
|
" `xgb.unserialize` instead. "
|
||||||
|
)
|
||||||
|
)
|
||||||
bst <- xgb.handleToBooster(handle, modelfile)
|
bst <- xgb.handleToBooster(handle, modelfile)
|
||||||
} else {
|
} else {
|
||||||
bst <- xgb.handleToBooster(handle, NULL)
|
bst <- xgb.handleToBooster(handle, NULL)
|
||||||
|
|||||||
@ -4,6 +4,14 @@
|
|||||||
#' Save xgboost model from xgboost or xgb.train
|
#' Save xgboost model from xgboost or xgb.train
|
||||||
#'
|
#'
|
||||||
#' @param model the model object.
|
#' @param model the model object.
|
||||||
|
#' @param raw_format The format for encoding the booster. Available options are
|
||||||
|
#' \itemize{
|
||||||
|
#' \item \code{json}: Encode the booster into JSON text document.
|
||||||
|
#' \item \code{ubj}: Encode the booster into Universal Binary JSON.
|
||||||
|
#' \item \code{deprecated}: Encode the booster into old customized binary format.
|
||||||
|
#' }
|
||||||
|
#'
|
||||||
|
#' Right now the default is \code{deprecated} but will be changed to \code{ubj} in upcoming release.
|
||||||
#'
|
#'
|
||||||
#' @examples
|
#' @examples
|
||||||
#' data(agaricus.train, package='xgboost')
|
#' data(agaricus.train, package='xgboost')
|
||||||
@ -17,7 +25,8 @@
|
|||||||
#' pred <- predict(bst, test$data)
|
#' pred <- predict(bst, test$data)
|
||||||
#'
|
#'
|
||||||
#' @export
|
#' @export
|
||||||
xgb.save.raw <- function(model) {
|
xgb.save.raw <- function(model, raw_format = "deprecated") {
|
||||||
handle <- xgb.get.handle(model)
|
handle <- xgb.get.handle(model)
|
||||||
.Call(XGBoosterModelToRaw_R, handle)
|
args <- list(format = raw_format)
|
||||||
|
.Call(XGBoosterSaveModelToRaw_R, handle, jsonlite::toJSON(args, auto_unbox = TRUE))
|
||||||
}
|
}
|
||||||
|
|||||||
@ -63,7 +63,7 @@ print(paste("sum(abs(pred2-pred))=", sum(abs(pred2 - pred))))
|
|||||||
# save model to R's raw vector
|
# save model to R's raw vector
|
||||||
raw <- xgb.save.raw(bst)
|
raw <- xgb.save.raw(bst)
|
||||||
# load binary model to R
|
# load binary model to R
|
||||||
bst3 <- xgb.load(raw)
|
bst3 <- xgb.load.raw(raw)
|
||||||
pred3 <- predict(bst3, test$data)
|
pred3 <- predict(bst3, test$data)
|
||||||
# pred3 should be identical to pred
|
# pred3 should be identical to pred
|
||||||
print(paste("sum(abs(pred3-pred))=", sum(abs(pred3 - pred))))
|
print(paste("sum(abs(pred3-pred))=", sum(abs(pred3 - pred))))
|
||||||
|
|||||||
@ -24,12 +24,12 @@ extern SEXP XGBoosterEvalOneIter_R(SEXP, SEXP, SEXP, SEXP);
|
|||||||
extern SEXP XGBoosterGetAttrNames_R(SEXP);
|
extern SEXP XGBoosterGetAttrNames_R(SEXP);
|
||||||
extern SEXP XGBoosterGetAttr_R(SEXP, SEXP);
|
extern SEXP XGBoosterGetAttr_R(SEXP, SEXP);
|
||||||
extern SEXP XGBoosterLoadModelFromRaw_R(SEXP, SEXP);
|
extern SEXP XGBoosterLoadModelFromRaw_R(SEXP, SEXP);
|
||||||
|
extern SEXP XGBoosterSaveModelToRaw_R(SEXP handle, SEXP config);
|
||||||
extern SEXP XGBoosterLoadModel_R(SEXP, SEXP);
|
extern SEXP XGBoosterLoadModel_R(SEXP, SEXP);
|
||||||
extern SEXP XGBoosterSaveJsonConfig_R(SEXP handle);
|
extern SEXP XGBoosterSaveJsonConfig_R(SEXP handle);
|
||||||
extern SEXP XGBoosterLoadJsonConfig_R(SEXP handle, SEXP value);
|
extern SEXP XGBoosterLoadJsonConfig_R(SEXP handle, SEXP value);
|
||||||
extern SEXP XGBoosterSerializeToBuffer_R(SEXP handle);
|
extern SEXP XGBoosterSerializeToBuffer_R(SEXP handle);
|
||||||
extern SEXP XGBoosterUnserializeFromBuffer_R(SEXP handle, SEXP raw);
|
extern SEXP XGBoosterUnserializeFromBuffer_R(SEXP handle, SEXP raw);
|
||||||
extern SEXP XGBoosterModelToRaw_R(SEXP);
|
|
||||||
extern SEXP XGBoosterPredict_R(SEXP, SEXP, SEXP, SEXP, SEXP);
|
extern SEXP XGBoosterPredict_R(SEXP, SEXP, SEXP, SEXP, SEXP);
|
||||||
extern SEXP XGBoosterPredictFromDMatrix_R(SEXP, SEXP, SEXP);
|
extern SEXP XGBoosterPredictFromDMatrix_R(SEXP, SEXP, SEXP);
|
||||||
extern SEXP XGBoosterSaveModel_R(SEXP, SEXP);
|
extern SEXP XGBoosterSaveModel_R(SEXP, SEXP);
|
||||||
@ -59,12 +59,12 @@ static const R_CallMethodDef CallEntries[] = {
|
|||||||
{"XGBoosterGetAttrNames_R", (DL_FUNC) &XGBoosterGetAttrNames_R, 1},
|
{"XGBoosterGetAttrNames_R", (DL_FUNC) &XGBoosterGetAttrNames_R, 1},
|
||||||
{"XGBoosterGetAttr_R", (DL_FUNC) &XGBoosterGetAttr_R, 2},
|
{"XGBoosterGetAttr_R", (DL_FUNC) &XGBoosterGetAttr_R, 2},
|
||||||
{"XGBoosterLoadModelFromRaw_R", (DL_FUNC) &XGBoosterLoadModelFromRaw_R, 2},
|
{"XGBoosterLoadModelFromRaw_R", (DL_FUNC) &XGBoosterLoadModelFromRaw_R, 2},
|
||||||
|
{"XGBoosterSaveModelToRaw_R", (DL_FUNC) &XGBoosterSaveModelToRaw_R, 2},
|
||||||
{"XGBoosterLoadModel_R", (DL_FUNC) &XGBoosterLoadModel_R, 2},
|
{"XGBoosterLoadModel_R", (DL_FUNC) &XGBoosterLoadModel_R, 2},
|
||||||
{"XGBoosterSaveJsonConfig_R", (DL_FUNC) &XGBoosterSaveJsonConfig_R, 1},
|
{"XGBoosterSaveJsonConfig_R", (DL_FUNC) &XGBoosterSaveJsonConfig_R, 1},
|
||||||
{"XGBoosterLoadJsonConfig_R", (DL_FUNC) &XGBoosterLoadJsonConfig_R, 2},
|
{"XGBoosterLoadJsonConfig_R", (DL_FUNC) &XGBoosterLoadJsonConfig_R, 2},
|
||||||
{"XGBoosterSerializeToBuffer_R", (DL_FUNC) &XGBoosterSerializeToBuffer_R, 1},
|
{"XGBoosterSerializeToBuffer_R", (DL_FUNC) &XGBoosterSerializeToBuffer_R, 1},
|
||||||
{"XGBoosterUnserializeFromBuffer_R", (DL_FUNC) &XGBoosterUnserializeFromBuffer_R, 2},
|
{"XGBoosterUnserializeFromBuffer_R", (DL_FUNC) &XGBoosterUnserializeFromBuffer_R, 2},
|
||||||
{"XGBoosterModelToRaw_R", (DL_FUNC) &XGBoosterModelToRaw_R, 1},
|
|
||||||
{"XGBoosterPredict_R", (DL_FUNC) &XGBoosterPredict_R, 5},
|
{"XGBoosterPredict_R", (DL_FUNC) &XGBoosterPredict_R, 5},
|
||||||
{"XGBoosterPredictFromDMatrix_R", (DL_FUNC) &XGBoosterPredictFromDMatrix_R, 3},
|
{"XGBoosterPredictFromDMatrix_R", (DL_FUNC) &XGBoosterPredictFromDMatrix_R, 3},
|
||||||
{"XGBoosterSaveModel_R", (DL_FUNC) &XGBoosterSaveModel_R, 2},
|
{"XGBoosterSaveModel_R", (DL_FUNC) &XGBoosterSaveModel_R, 2},
|
||||||
|
|||||||
@ -429,21 +429,6 @@ XGB_DLL SEXP XGBoosterSaveModel_R(SEXP handle, SEXP fname) {
|
|||||||
return R_NilValue;
|
return R_NilValue;
|
||||||
}
|
}
|
||||||
|
|
||||||
XGB_DLL SEXP XGBoosterModelToRaw_R(SEXP handle) {
|
|
||||||
SEXP ret;
|
|
||||||
R_API_BEGIN();
|
|
||||||
bst_ulong olen;
|
|
||||||
const char *raw;
|
|
||||||
CHECK_CALL(XGBoosterGetModelRaw(R_ExternalPtrAddr(handle), &olen, &raw));
|
|
||||||
ret = PROTECT(allocVector(RAWSXP, olen));
|
|
||||||
if (olen != 0) {
|
|
||||||
memcpy(RAW(ret), raw, olen);
|
|
||||||
}
|
|
||||||
R_API_END();
|
|
||||||
UNPROTECT(1);
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
|
|
||||||
XGB_DLL SEXP XGBoosterLoadModelFromRaw_R(SEXP handle, SEXP raw) {
|
XGB_DLL SEXP XGBoosterLoadModelFromRaw_R(SEXP handle, SEXP raw) {
|
||||||
R_API_BEGIN();
|
R_API_BEGIN();
|
||||||
CHECK_CALL(XGBoosterLoadModelFromBuffer(R_ExternalPtrAddr(handle),
|
CHECK_CALL(XGBoosterLoadModelFromBuffer(R_ExternalPtrAddr(handle),
|
||||||
@ -453,6 +438,22 @@ XGB_DLL SEXP XGBoosterLoadModelFromRaw_R(SEXP handle, SEXP raw) {
|
|||||||
return R_NilValue;
|
return R_NilValue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
XGB_DLL SEXP XGBoosterSaveModelToRaw_R(SEXP handle, SEXP json_config) {
|
||||||
|
SEXP ret;
|
||||||
|
R_API_BEGIN();
|
||||||
|
bst_ulong olen;
|
||||||
|
char const *c_json_config = CHAR(asChar(json_config));
|
||||||
|
char const *raw;
|
||||||
|
CHECK_CALL(XGBoosterSaveModelToBuffer(R_ExternalPtrAddr(handle), c_json_config, &olen, &raw))
|
||||||
|
ret = PROTECT(allocVector(RAWSXP, olen));
|
||||||
|
if (olen != 0) {
|
||||||
|
std::memcpy(RAW(ret), raw, olen);
|
||||||
|
}
|
||||||
|
R_API_END();
|
||||||
|
UNPROTECT(1);
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
XGB_DLL SEXP XGBoosterSaveJsonConfig_R(SEXP handle) {
|
XGB_DLL SEXP XGBoosterSaveJsonConfig_R(SEXP handle) {
|
||||||
const char* ret;
|
const char* ret;
|
||||||
R_API_BEGIN();
|
R_API_BEGIN();
|
||||||
|
|||||||
@ -209,11 +209,21 @@ XGB_DLL SEXP XGBoosterSaveModel_R(SEXP handle, SEXP fname);
|
|||||||
XGB_DLL SEXP XGBoosterLoadModelFromRaw_R(SEXP handle, SEXP raw);
|
XGB_DLL SEXP XGBoosterLoadModelFromRaw_R(SEXP handle, SEXP raw);
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief save model into R's raw array
|
* \brief Save model into R's raw array
|
||||||
|
*
|
||||||
* \param handle handle
|
* \param handle handle
|
||||||
* \return raw array
|
* \param json_config JSON encoded string storing parameters for the function. Following
|
||||||
|
* keys are expected in the JSON document:
|
||||||
|
*
|
||||||
|
* "format": str
|
||||||
|
* - json: Output booster will be encoded as JSON.
|
||||||
|
* - ubj: Output booster will be encoded as Univeral binary JSON.
|
||||||
|
* - deprecated: Output booster will be encoded as old custom binary format. Do now use
|
||||||
|
* this format except for compatibility reasons.
|
||||||
|
*
|
||||||
|
* \return Raw array
|
||||||
*/
|
*/
|
||||||
XGB_DLL SEXP XGBoosterModelToRaw_R(SEXP handle);
|
XGB_DLL SEXP XGBoosterSaveModelToRaw_R(SEXP handle, SEXP json_config);
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief Save internal parameters as a JSON string
|
* \brief Save internal parameters as a JSON string
|
||||||
|
|||||||
39
R-package/tests/testthat/test_io.R
Normal file
39
R-package/tests/testthat/test_io.R
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
context("Test model IO.")
|
||||||
|
## some other tests are in test_basic.R
|
||||||
|
require(xgboost)
|
||||||
|
require(testthat)
|
||||||
|
|
||||||
|
data(agaricus.train, package = "xgboost")
|
||||||
|
data(agaricus.test, package = "xgboost")
|
||||||
|
train <- agaricus.train
|
||||||
|
test <- agaricus.test
|
||||||
|
|
||||||
|
test_that("load/save raw works", {
|
||||||
|
nrounds <- 8
|
||||||
|
booster <- xgboost(
|
||||||
|
data = train$data, label = train$label,
|
||||||
|
nrounds = nrounds, objective = "binary:logistic"
|
||||||
|
)
|
||||||
|
|
||||||
|
json_bytes <- xgb.save.raw(booster, raw_format = "json")
|
||||||
|
ubj_bytes <- xgb.save.raw(booster, raw_format = "ubj")
|
||||||
|
old_bytes <- xgb.save.raw(booster, raw_format = "deprecated")
|
||||||
|
|
||||||
|
from_json <- xgb.load.raw(json_bytes)
|
||||||
|
from_ubj <- xgb.load.raw(ubj_bytes)
|
||||||
|
|
||||||
|
## FIXME(jiamingy): Should we include these 3 lines into `xgb.load.raw`?
|
||||||
|
from_json <- list(handle = from_json, raw = NULL)
|
||||||
|
class(from_json) <- "xgb.Booster"
|
||||||
|
from_json <- xgb.Booster.complete(from_json, saveraw = TRUE)
|
||||||
|
|
||||||
|
from_ubj <- list(handle = from_ubj, raw = NULL)
|
||||||
|
class(from_ubj) <- "xgb.Booster"
|
||||||
|
from_ubj <- xgb.Booster.complete(from_ubj, saveraw = TRUE)
|
||||||
|
|
||||||
|
json2old <- xgb.save.raw(from_json, raw_format = "deprecated")
|
||||||
|
ubj2old <- xgb.save.raw(from_ubj, raw_format = "deprecated")
|
||||||
|
|
||||||
|
expect_equal(json2old, ubj2old)
|
||||||
|
expect_equal(json2old, old_bytes)
|
||||||
|
})
|
||||||
Loading…
x
Reference in New Issue
Block a user