[R] R raw serialization. (#5123)

* Add bindings for serialization.
* Change `xgb.save.raw' into full serialization instead of simple model.
* Add `xgb.load.raw' for unserialization.
* Run devtools.
This commit is contained in:
Jiaming Yuan 2020-04-11 17:16:54 +08:00 committed by GitHub
parent a3db79df22
commit b56c902841
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 260 additions and 65 deletions

View File

@ -32,3 +32,7 @@ set_target_properties(
set(XGBOOST_DEFINITIONS "${XGBOOST_DEFINITIONS};${R_DEFINITIONS}" PARENT_SCOPE) set(XGBOOST_DEFINITIONS "${XGBOOST_DEFINITIONS};${R_DEFINITIONS}" PARENT_SCOPE)
set(XGBOOST_OBJ_SOURCES $<TARGET_OBJECTS:xgboost-r> PARENT_SCOPE) set(XGBOOST_OBJ_SOURCES $<TARGET_OBJECTS:xgboost-r> PARENT_SCOPE)
set(LINKED_LIBRARIES_PRIVATE ${LINKED_LIBRARIES_PRIVATE} ${LIBR_CORE_LIBRARY} PARENT_SCOPE) set(LINKED_LIBRARIES_PRIVATE ${LINKED_LIBRARIES_PRIVATE} ${LIBR_CORE_LIBRARY} PARENT_SCOPE)
if (USE_OPENMP)
target_link_libraries(xgboost-r PRIVATE OpenMP::OpenMP_CXX)
endif ()

View File

@ -63,5 +63,5 @@ Imports:
data.table (>= 1.9.6), data.table (>= 1.9.6),
magrittr (>= 1.5), magrittr (>= 1.5),
stringi (>= 0.5.2) stringi (>= 0.5.2)
RoxygenNote: 7.0.2 RoxygenNote: 7.1.0
SystemRequirements: GNU make, C++11 SystemRequirements: GNU make, C++11

View File

@ -40,6 +40,7 @@ export(xgb.ggplot.deepness)
export(xgb.ggplot.importance) export(xgb.ggplot.importance)
export(xgb.importance) export(xgb.importance)
export(xgb.load) export(xgb.load)
export(xgb.load.raw)
export(xgb.model.dt.tree) export(xgb.model.dt.tree)
export(xgb.plot.deepness) export(xgb.plot.deepness)
export(xgb.plot.importance) export(xgb.plot.importance)
@ -48,7 +49,9 @@ export(xgb.plot.shap)
export(xgb.plot.tree) export(xgb.plot.tree)
export(xgb.save) export(xgb.save)
export(xgb.save.raw) export(xgb.save.raw)
export(xgb.serialize)
export(xgb.train) export(xgb.train)
export(xgb.unserialize)
export(xgboost) export(xgboost)
import(methods) import(methods)
importClassesFrom(Matrix,dgCMatrix) importClassesFrom(Matrix,dgCMatrix)

View File

@ -5,20 +5,34 @@ xgb.Booster.handle <- function(params = list(), cachelist = list(), modelfile =
!all(vapply(cachelist, inherits, logical(1), what = 'xgb.DMatrix'))) { !all(vapply(cachelist, inherits, logical(1), what = 'xgb.DMatrix'))) {
stop("cachelist must be a list of xgb.DMatrix objects") stop("cachelist must be a list of xgb.DMatrix objects")
} }
## Load existing model, dispatch for on disk model file and in memory buffer
handle <- .Call(XGBoosterCreate_R, cachelist)
if (!is.null(modelfile)) { if (!is.null(modelfile)) {
if (typeof(modelfile) == "character") { if (typeof(modelfile) == "character") {
## A filename
handle <- .Call(XGBoosterCreate_R, cachelist)
.Call(XGBoosterLoadModel_R, handle, modelfile[1]) .Call(XGBoosterLoadModel_R, handle, modelfile[1])
class(handle) <- "xgb.Booster.handle"
if (length(params) > 0) {
xgb.parameters(handle) <- params
}
return(handle)
} else if (typeof(modelfile) == "raw") { } else if (typeof(modelfile) == "raw") {
.Call(XGBoosterLoadModelFromRaw_R, handle, modelfile) ## A memory buffer
bst <- xgb.unserialize(modelfile)
xgb.parameters(bst) <- params
return (bst)
} else if (inherits(modelfile, "xgb.Booster")) { } else if (inherits(modelfile, "xgb.Booster")) {
## A booster object
bst <- xgb.Booster.complete(modelfile, saveraw = TRUE) bst <- xgb.Booster.complete(modelfile, saveraw = TRUE)
.Call(XGBoosterLoadModelFromRaw_R, handle, bst$raw) bst <- xgb.unserialize(bst$raw)
xgb.parameters(bst) <- params
return (bst)
} else { } else {
stop("modelfile must be either character filename, or raw booster dump, or xgb.Booster object") stop("modelfile must be either character filename, or raw booster dump, or xgb.Booster object")
} }
} }
## Create new model
handle <- .Call(XGBoosterCreate_R, cachelist)
class(handle) <- "xgb.Booster.handle" class(handle) <- "xgb.Booster.handle"
if (length(params) > 0) { if (length(params) > 0) {
xgb.parameters(handle) <- params xgb.parameters(handle) <- params
@ -113,8 +127,9 @@ xgb.Booster.complete <- function(object, saveraw = TRUE) {
if (is.null.handle(object$handle)) { if (is.null.handle(object$handle)) {
object$handle <- xgb.Booster.handle(modelfile = object$raw) object$handle <- xgb.Booster.handle(modelfile = object$raw)
} else { } else {
if (is.null(object$raw) && saveraw) if (is.null(object$raw) && saveraw) {
object$raw <- xgb.save.raw(object$handle) object$raw <- xgb.serialize(object$handle)
}
} }
return(object) return(object)
} }
@ -399,7 +414,7 @@ predict.xgb.Booster.handle <- function(object, ...) {
#' That would only matter if attributes need to be set many times. #' That would only matter if attributes need to be set many times.
#' Note, however, that when feeding a handle of an \code{xgb.Booster} object to the attribute setters, #' Note, however, that when feeding a handle of an \code{xgb.Booster} object to the attribute setters,
#' the raw model cache of an \code{xgb.Booster} object would not be automatically updated, #' the raw model cache of an \code{xgb.Booster} object would not be automatically updated,
#' and it would be user's responsibility to call \code{xgb.save.raw} to update it. #' and it would be user's responsibility to call \code{xgb.serialize} to update it.
#' #'
#' The \code{xgb.attributes<-} setter either updates the existing or adds one or several attributes, #' The \code{xgb.attributes<-} setter either updates the existing or adds one or several attributes,
#' but it doesn't delete the other existing attributes. #' but it doesn't delete the other existing attributes.
@ -458,7 +473,7 @@ xgb.attr <- function(object, name) {
} }
.Call(XGBoosterSetAttr_R, handle, as.character(name[1]), value) .Call(XGBoosterSetAttr_R, handle, as.character(name[1]), value)
if (is(object, 'xgb.Booster') && !is.null(object$raw)) { if (is(object, 'xgb.Booster') && !is.null(object$raw)) {
object$raw <- xgb.save.raw(object$handle) object$raw <- xgb.serialize(object$handle)
} }
object object
} }
@ -498,7 +513,7 @@ xgb.attributes <- function(object) {
.Call(XGBoosterSetAttr_R, handle, names(a[i]), a[[i]]) .Call(XGBoosterSetAttr_R, handle, names(a[i]), a[[i]])
} }
if (is(object, 'xgb.Booster') && !is.null(object$raw)) { if (is(object, 'xgb.Booster') && !is.null(object$raw)) {
object$raw <- xgb.save.raw(object$handle) object$raw <- xgb.serialize(object$handle)
} }
object object
} }
@ -528,7 +543,8 @@ xgb.config <- function(object) {
`xgb.config<-` <- function(object, value) { `xgb.config<-` <- function(object, value) {
handle <- xgb.get.handle(object) handle <- xgb.get.handle(object)
.Call(XGBoosterLoadJsonConfig_R, handle, value) .Call(XGBoosterLoadJsonConfig_R, handle, value)
object$raw <- xgb.Booster.complete(object) object$raw <- NULL # force renew the raw buffer
object <- xgb.Booster.complete(object)
object object
} }
@ -568,7 +584,7 @@ xgb.config <- function(object) {
.Call(XGBoosterSetParam_R, handle, names(p[i]), p[[i]]) .Call(XGBoosterSetParam_R, handle, names(p[i]), p[[i]])
} }
if (is(object, 'xgb.Booster') && !is.null(object$raw)) { if (is(object, 'xgb.Booster') && !is.null(object$raw)) {
object$raw <- xgb.save.raw(object$handle) object$raw <- xgb.serialize(object$handle)
} }
object object
} }

View File

@ -0,0 +1,14 @@
#' Load serialised xgboost model from R's raw vector
#'
#' User can generate raw memory buffer by calling xgb.save.raw
#'
#' @param buffer the buffer returned by xgb.save.raw
#'
#' @export
xgb.load.raw <- function(buffer) {
  # Start from an empty booster created without any cached DMatrix objects.
  bst_handle <- .Call(XGBoosterCreate_R, list())
  # Populate the booster from the raw vector holding the saved model.
  .Call(XGBoosterLoadModelFromRaw_R, bst_handle, buffer)
  # Tag the result with the handle class; note that a bare handle is
  # returned here, not an object of class xgb.Booster.
  class(bst_handle) <- "xgb.Booster.handle"
  bst_handle
}

View File

@ -1,5 +1,5 @@
#' Save xgboost model to R's raw vector, #' Save xgboost model to R's raw vector,
#' user can call xgb.load to load the model back from raw vector #' user can call xgb.load.raw to load the model back from raw vector
#' #'
#' Save xgboost model from xgboost or xgb.train #' Save xgboost model from xgboost or xgb.train
#' #'
@ -13,11 +13,11 @@
#' bst <- xgboost(data = train$data, label = train$label, max_depth = 2, #' bst <- xgboost(data = train$data, label = train$label, max_depth = 2,
#' eta = 1, nthread = 2, nrounds = 2,objective = "binary:logistic") #' eta = 1, nthread = 2, nrounds = 2,objective = "binary:logistic")
#' raw <- xgb.save.raw(bst) #' raw <- xgb.save.raw(bst)
#' bst <- xgb.load(raw) #' bst <- xgb.load.raw(raw)
#' pred <- predict(bst, test$data) #' pred <- predict(bst, test$data)
#' #'
#' @export #' @export
xgb.save.raw <- function(model) { xgb.save.raw <- function(model) {
model <- xgb.get.handle(model) handle <- xgb.get.handle(model)
.Call(XGBoosterModelToRaw_R, model) .Call(XGBoosterModelToRaw_R, handle)
} }

View File

@ -0,0 +1,21 @@
#' Serialize the booster instance into R's raw vector. The serialization method differs
#' from \code{\link{xgb.save.raw}} as the latter one saves only the model but not
#' parameters. This serialization format is not stable across different xgboost versions.
#'
#' @param booster the booster instance
#'
#' @examples
#' data(agaricus.train, package='xgboost')
#' data(agaricus.test, package='xgboost')
#' train <- agaricus.train
#' test <- agaricus.test
#' bst <- xgboost(data = train$data, label = train$label, max_depth = 2,
#' eta = 1, nthread = 2, nrounds = 2,objective = "binary:logistic")
#' raw <- xgb.serialize(bst)
#' bst <- xgb.unserialize(raw)
#'
#' @export
xgb.serialize <- function(booster) {
  # Resolve the booster argument to its handle via xgb.get.handle() and hand
  # it to the native serializer; the value of the .Call is returned as-is.
  .Call(XGBoosterSerializeToBuffer_R, xgb.get.handle(booster))
}

View File

@ -0,0 +1,12 @@
#' Load the instance back from \code{\link{xgb.serialize}}
#'
#' @param buffer the buffer containing booster instance saved by \code{\link{xgb.serialize}}
#'
#' @export
xgb.unserialize <- function(buffer) {
  # Start from an empty booster created without any cached DMatrix objects.
  bst_handle <- .Call(XGBoosterCreate_R, list())
  # Restore the booster from the buffer produced by xgb.serialize().
  .Call(XGBoosterUnserializeFromBuffer_R, bst_handle, buffer)
  # Tag the result with the handle class; a bare handle is returned.
  class(bst_handle) <- "xgb.Booster.handle"
  bst_handle
}

View File

@ -4,8 +4,10 @@
\name{agaricus.test} \name{agaricus.test}
\alias{agaricus.test} \alias{agaricus.test}
\title{Test part from Mushroom Data Set} \title{Test part from Mushroom Data Set}
\format{A list containing a label vector, and a dgCMatrix object with 1611 \format{
rows and 126 variables} A list containing a label vector, and a dgCMatrix object with 1611
rows and 126 variables
}
\usage{ \usage{
data(agaricus.test) data(agaricus.test)
} }

View File

@ -4,8 +4,10 @@
\name{agaricus.train} \name{agaricus.train}
\alias{agaricus.train} \alias{agaricus.train}
\title{Training part from Mushroom Data Set} \title{Training part from Mushroom Data Set}
\format{A list containing a label vector, and a dgCMatrix object with 6513 \format{
rows and 127 variables} A list containing a label vector, and a dgCMatrix object with 6513
rows and 127 variables
}
\usage{ \usage{
data(agaricus.train) data(agaricus.train)
} }

View File

@ -55,7 +55,7 @@ than for \code{xgb.Booster}, since only just a handle (pointer) would need to be
That would only matter if attributes need to be set many times. That would only matter if attributes need to be set many times.
Note, however, that when feeding a handle of an \code{xgb.Booster} object to the attribute setters, Note, however, that when feeding a handle of an \code{xgb.Booster} object to the attribute setters,
the raw model cache of an \code{xgb.Booster} object would not be automatically updated, the raw model cache of an \code{xgb.Booster} object would not be automatically updated,
and it would be user's responsibility to call \code{xgb.save.raw} to update it. and it would be user's responsibility to call \code{xgb.serialize} to update it.
The \code{xgb.attributes<-} setter either updates the existing or adds one or several attributes, The \code{xgb.attributes<-} setter either updates the existing or adds one or several attributes,
but it doesn't delete the other existing attributes. but it doesn't delete the other existing attributes.

View File

@ -135,7 +135,7 @@ An object of class \code{xgb.cv.synchronous} with the following elements:
(only available with early stopping). (only available with early stopping).
\item \code{pred} CV prediction values available when \code{prediction} is set. \item \code{pred} CV prediction values available when \code{prediction} is set.
It is either vector or matrix (see \code{\link{cb.cv.predict}}). It is either vector or matrix (see \code{\link{cb.cv.predict}}).
\item \code{models} a liost of the CV folds' models. It is only available with the explicit \item \code{models} a list of the CV folds' models. It is only available with the explicit
setting of the \code{cb.cv.predict(save_models = TRUE)} callback. setting of the \code{cb.cv.predict(save_models = TRUE)} callback.
} }
} }

View File

@ -0,0 +1,14 @@
% Generated by roxygen2: do not edit by hand
% Please edit documentation in R/xgb.load.raw.R
\name{xgb.load.raw}
\alias{xgb.load.raw}
\title{Load serialised xgboost model from R's raw vector}
\usage{
xgb.load.raw(buffer)
}
\arguments{
\item{buffer}{the buffer returned by xgb.save.raw}
}
\description{
User can generate raw memory buffer by calling xgb.save.raw
}

View File

@ -3,7 +3,7 @@
\name{xgb.save.raw} \name{xgb.save.raw}
\alias{xgb.save.raw} \alias{xgb.save.raw}
\title{Save xgboost model to R's raw vector, \title{Save xgboost model to R's raw vector,
user can call xgb.load to load the model back from raw vector} user can call xgb.load.raw to load the model back from raw vector}
\usage{ \usage{
xgb.save.raw(model) xgb.save.raw(model)
} }
@ -21,7 +21,7 @@ test <- agaricus.test
bst <- xgboost(data = train$data, label = train$label, max_depth = 2, bst <- xgboost(data = train$data, label = train$label, max_depth = 2,
eta = 1, nthread = 2, nrounds = 2,objective = "binary:logistic") eta = 1, nthread = 2, nrounds = 2,objective = "binary:logistic")
raw <- xgb.save.raw(bst) raw <- xgb.save.raw(bst)
bst <- xgb.load(raw) bst <- xgb.load.raw(raw)
pred <- predict(bst, test$data) pred <- predict(bst, test$data)
} }

View File

@ -0,0 +1,29 @@
% Generated by roxygen2: do not edit by hand
% Please edit documentation in R/xgb.serialize.R
\name{xgb.serialize}
\alias{xgb.serialize}
\title{Serialize the booster instance into R's raw vector. The serialization method differs
from \code{\link{xgb.save.raw}} as the latter one saves only the model but not
parameters. This serialization format is not stable across different xgboost versions.}
\usage{
xgb.serialize(booster)
}
\arguments{
\item{booster}{the booster instance}
}
\description{
Serialize the booster instance into R's raw vector. The serialization method differs
from \code{\link{xgb.save.raw}} as the latter one saves only the model but not
parameters. This serialization format is not stable across different xgboost versions.
}
\examples{
data(agaricus.train, package='xgboost')
data(agaricus.test, package='xgboost')
train <- agaricus.train
test <- agaricus.test
bst <- xgboost(data = train$data, label = train$label, max_depth = 2,
eta = 1, nthread = 2, nrounds = 2,objective = "binary:logistic")
raw <- xgb.serialize(bst)
bst <- xgb.unserialize(raw)
}

View File

@ -0,0 +1,14 @@
% Generated by roxygen2: do not edit by hand
% Please edit documentation in R/xgb.unserialize.R
\name{xgb.unserialize}
\alias{xgb.unserialize}
\title{Load the instance back from \code{\link{xgb.serialize}}}
\usage{
xgb.unserialize(buffer)
}
\arguments{
\item{buffer}{the buffer containing booster instance saved by \code{\link{xgb.serialize}}}
}
\description{
Load the instance back from \code{\link{xgb.serialize}}
}

View File

@ -25,6 +25,8 @@ extern SEXP XGBoosterLoadModelFromRaw_R(SEXP, SEXP);
extern SEXP XGBoosterLoadModel_R(SEXP, SEXP); extern SEXP XGBoosterLoadModel_R(SEXP, SEXP);
extern SEXP XGBoosterSaveJsonConfig_R(SEXP handle); extern SEXP XGBoosterSaveJsonConfig_R(SEXP handle);
extern SEXP XGBoosterLoadJsonConfig_R(SEXP handle, SEXP value); extern SEXP XGBoosterLoadJsonConfig_R(SEXP handle, SEXP value);
extern SEXP XGBoosterSerializeToBuffer_R(SEXP handle);
extern SEXP XGBoosterUnserializeFromBuffer_R(SEXP handle, SEXP raw);
extern SEXP XGBoosterModelToRaw_R(SEXP); extern SEXP XGBoosterModelToRaw_R(SEXP);
extern SEXP XGBoosterPredict_R(SEXP, SEXP, SEXP, SEXP, SEXP); extern SEXP XGBoosterPredict_R(SEXP, SEXP, SEXP, SEXP, SEXP);
extern SEXP XGBoosterSaveModel_R(SEXP, SEXP); extern SEXP XGBoosterSaveModel_R(SEXP, SEXP);
@ -53,6 +55,8 @@ static const R_CallMethodDef CallEntries[] = {
{"XGBoosterLoadModel_R", (DL_FUNC) &XGBoosterLoadModel_R, 2}, {"XGBoosterLoadModel_R", (DL_FUNC) &XGBoosterLoadModel_R, 2},
{"XGBoosterSaveJsonConfig_R", (DL_FUNC) &XGBoosterSaveJsonConfig_R, 1}, {"XGBoosterSaveJsonConfig_R", (DL_FUNC) &XGBoosterSaveJsonConfig_R, 1},
{"XGBoosterLoadJsonConfig_R", (DL_FUNC) &XGBoosterLoadJsonConfig_R, 2}, {"XGBoosterLoadJsonConfig_R", (DL_FUNC) &XGBoosterLoadJsonConfig_R, 2},
{"XGBoosterSerializeToBuffer_R", (DL_FUNC) &XGBoosterSerializeToBuffer_R, 1},
{"XGBoosterUnserializeFromBuffer_R", (DL_FUNC) &XGBoosterUnserializeFromBuffer_R, 2},
{"XGBoosterModelToRaw_R", (DL_FUNC) &XGBoosterModelToRaw_R, 1}, {"XGBoosterModelToRaw_R", (DL_FUNC) &XGBoosterModelToRaw_R, 1},
{"XGBoosterPredict_R", (DL_FUNC) &XGBoosterPredict_R, 5}, {"XGBoosterPredict_R", (DL_FUNC) &XGBoosterPredict_R, 5},
{"XGBoosterSaveModel_R", (DL_FUNC) &XGBoosterSaveModel_R, 2}, {"XGBoosterSaveModel_R", (DL_FUNC) &XGBoosterSaveModel_R, 2},

View File

@ -338,15 +338,6 @@ SEXP XGBoosterSaveModel_R(SEXP handle, SEXP fname) {
return R_NilValue; return R_NilValue;
} }
SEXP XGBoosterLoadModelFromRaw_R(SEXP handle, SEXP raw) {
R_API_BEGIN();
CHECK_CALL(XGBoosterLoadModelFromBuffer(R_ExternalPtrAddr(handle),
RAW(raw),
length(raw)));
R_API_END();
return R_NilValue;
}
SEXP XGBoosterModelToRaw_R(SEXP handle) { SEXP XGBoosterModelToRaw_R(SEXP handle) {
SEXP ret; SEXP ret;
R_API_BEGIN(); R_API_BEGIN();
@ -362,6 +353,15 @@ SEXP XGBoosterModelToRaw_R(SEXP handle) {
return ret; return ret;
} }
// Load a model into the booster behind `handle` from the bytes of the R raw
// vector `raw`. Always returns R_NilValue; a failing C API call is reported
// through CHECK_CALL.
SEXP XGBoosterLoadModelFromRaw_R(SEXP handle, SEXP raw) {
  R_API_BEGIN();
  // Unwrap the R external pointer to get the native booster.
  void *booster = R_ExternalPtrAddr(handle);
  CHECK_CALL(XGBoosterLoadModelFromBuffer(booster, RAW(raw), length(raw)));
  R_API_END();
  return R_NilValue;
}
SEXP XGBoosterSaveJsonConfig_R(SEXP handle) { SEXP XGBoosterSaveJsonConfig_R(SEXP handle) {
const char* ret; const char* ret;
R_API_BEGIN(); R_API_BEGIN();
@ -380,6 +380,30 @@ SEXP XGBoosterLoadJsonConfig_R(SEXP handle, SEXP value) {
return R_NilValue; return R_NilValue;
} }
// Serialize the booster behind `handle` and return the bytes as an R raw
// vector (RAWSXP). The vector is empty when the C API reports zero bytes.
SEXP XGBoosterSerializeToBuffer_R(SEXP handle) {
  SEXP serialized;
  R_API_BEGIN();
  bst_ulong n_bytes;
  const char *bytes;
  void *booster = R_ExternalPtrAddr(handle);
  CHECK_CALL(XGBoosterSerializeToBuffer(booster, &n_bytes, &bytes));
  // Allocate the R-side vector and keep it protected until just before return.
  serialized = PROTECT(allocVector(RAWSXP, n_bytes));
  // Copy the serialized bytes into the R vector; nothing to copy when empty.
  if (n_bytes > 0) {
    memcpy(RAW(serialized), bytes, n_bytes);
  }
  R_API_END();
  UNPROTECT(1);
  return serialized;
}
// Restore the booster behind `handle` from the bytes of the R raw vector
// `raw`, as produced by XGBoosterSerializeToBuffer_R. Always returns
// R_NilValue.
SEXP XGBoosterUnserializeFromBuffer_R(SEXP handle, SEXP raw) {
  R_API_BEGIN();
  // Wrap the call in CHECK_CALL like the sibling wrappers do, so that a
  // failing status from the C API (e.g. a corrupt or incompatible buffer) is
  // reported instead of being silently discarded.
  CHECK_CALL(XGBoosterUnserializeFromBuffer(R_ExternalPtrAddr(handle),
                                            RAW(raw),
                                            length(raw)));
  R_API_END();
  return R_NilValue;
}
SEXP XGBoosterDumpModel_R(SEXP handle, SEXP fmap, SEXP with_stats, SEXP dump_format) { SEXP XGBoosterDumpModel_R(SEXP handle, SEXP fmap, SEXP with_stats, SEXP dump_format) {
SEXP out; SEXP out;
R_API_BEGIN(); R_API_BEGIN();

View File

@ -187,6 +187,7 @@ XGB_DLL SEXP XGBoosterModelToRaw_R(SEXP handle);
* \param handle handle * \param handle handle
* \return JSON string * \return JSON string
*/ */
XGB_DLL SEXP XGBoosterSaveJsonConfig_R(SEXP handle); XGB_DLL SEXP XGBoosterSaveJsonConfig_R(SEXP handle);
/*! /*!
* \brief Load the JSON string returnd by XGBoosterSaveJsonConfig_R * \brief Load the JSON string returnd by XGBoosterSaveJsonConfig_R
@ -195,6 +196,22 @@ XGB_DLL SEXP XGBoosterSaveJsonConfig_R(SEXP handle);
* \return R_NilValue * \return R_NilValue
*/ */
XGB_DLL SEXP XGBoosterLoadJsonConfig_R(SEXP handle, SEXP value); XGB_DLL SEXP XGBoosterLoadJsonConfig_R(SEXP handle, SEXP value);
/*!
 * \brief Memory snapshot based serialization method. Saves the state of the
 *  booster into a buffer.
 * \param handle handle to booster
 * \return raw byte array
 */
XGB_DLL SEXP XGBoosterSerializeToBuffer_R(SEXP handle);
/*!
 * \brief Memory snapshot based serialization method. Loads the buffer returned
 *  from `XGBoosterSerializeToBuffer_R'.
 * \param handle handle to booster
 * \param raw raw byte array holding the serialized booster
 * \return R_NilValue
 */
XGB_DLL SEXP XGBoosterUnserializeFromBuffer_R(SEXP handle, SEXP raw);
/*! /*!
* \brief dump model into a string * \brief dump model into a string
* \param handle handle * \param handle handle

View File

@ -219,6 +219,21 @@ test_that("training continuation works", {
expect_equal(dim(bst2$evaluation_log), c(2, 2)) expect_equal(dim(bst2$evaluation_log), c(2, 2))
}) })
test_that("model serialization works", {
  # Write to a temporary file instead of a fixed name in the working
  # directory, so the test leaves no artifact behind and cannot collide with
  # other tests.
  out_path <- tempfile(fileext = ".rds")
  dtrain <- xgb.DMatrix(train$data, label = train$label)
  watchlist <- list(train = dtrain)
  param <- list(objective = "binary:logistic")
  booster <- xgb.train(param, dtrain, nrounds = 4, watchlist = watchlist)

  # Round-trip the serialized booster through an RDS file on disk.
  raw <- xgb.serialize(booster)
  saveRDS(raw, out_path)
  raw <- readRDS(out_path)
  # Clean up before the expectation so the file is removed even if it fails.
  unlink(out_path)

  # Re-serializing the restored booster must reproduce the same bytes.
  loaded <- xgb.unserialize(raw)
  raw_from_loaded <- xgb.serialize(loaded)
  expect_equal(raw, raw_from_loaded)
})
test_that("xgb.cv works", { test_that("xgb.cv works", {
set.seed(11) set.seed(11)

View File

@ -184,6 +184,9 @@ test_that("cb.save.model works as expected", {
expect_equal(xgb.ntree(b1), 1) expect_equal(xgb.ntree(b1), 1)
b2 <- xgb.load('xgboost_02.model') b2 <- xgb.load('xgboost_02.model')
expect_equal(xgb.ntree(b2), 2) expect_equal(xgb.ntree(b2), 2)
xgb.config(b2) <- xgb.config(bst)
expect_equal(xgb.config(bst), xgb.config(b2))
expect_equal(bst$raw, b2$raw) expect_equal(bst$raw, b2$raw)
# save_period = 0 saves the last iteration's model # save_period = 0 saves the last iteration's model
@ -191,6 +194,7 @@ test_that("cb.save.model works as expected", {
save_period = 0) save_period = 0)
expect_true(file.exists('xgboost.model')) expect_true(file.exists('xgboost.model'))
b2 <- xgb.load('xgboost.model') b2 <- xgb.load('xgboost.model')
xgb.config(b2) <- xgb.config(bst)
expect_equal(bst$raw, b2$raw) expect_equal(bst$raw, b2$raw)
for (f in files) if (file.exists(f)) file.remove(f) for (f in files) if (file.exists(f)) file.remove(f)