[R] R raw serialization. (#5123)
* Add bindings for serialization. * Change `xgb.save.raw' into full serialization instead of simple model. * Add `xgb.load.raw' for unserialization. * Run devtools.
This commit is contained in:
parent
a3db79df22
commit
b56c902841
@ -32,3 +32,7 @@ set_target_properties(
|
|||||||
set(XGBOOST_DEFINITIONS "${XGBOOST_DEFINITIONS};${R_DEFINITIONS}" PARENT_SCOPE)
|
set(XGBOOST_DEFINITIONS "${XGBOOST_DEFINITIONS};${R_DEFINITIONS}" PARENT_SCOPE)
|
||||||
set(XGBOOST_OBJ_SOURCES $<TARGET_OBJECTS:xgboost-r> PARENT_SCOPE)
|
set(XGBOOST_OBJ_SOURCES $<TARGET_OBJECTS:xgboost-r> PARENT_SCOPE)
|
||||||
set(LINKED_LIBRARIES_PRIVATE ${LINKED_LIBRARIES_PRIVATE} ${LIBR_CORE_LIBRARY} PARENT_SCOPE)
|
set(LINKED_LIBRARIES_PRIVATE ${LINKED_LIBRARIES_PRIVATE} ${LIBR_CORE_LIBRARY} PARENT_SCOPE)
|
||||||
|
|
||||||
|
# When OpenMP support is enabled, link the imported OpenMP C++ target into
# xgboost-r. PRIVATE keeps the dependency out of the target's link interface.
if (USE_OPENMP)
  target_link_libraries(xgboost-r PRIVATE OpenMP::OpenMP_CXX)
endif ()
|
||||||
|
|||||||
@ -63,5 +63,5 @@ Imports:
|
|||||||
data.table (>= 1.9.6),
|
data.table (>= 1.9.6),
|
||||||
magrittr (>= 1.5),
|
magrittr (>= 1.5),
|
||||||
stringi (>= 0.5.2)
|
stringi (>= 0.5.2)
|
||||||
RoxygenNote: 7.0.2
|
RoxygenNote: 7.1.0
|
||||||
SystemRequirements: GNU make, C++11
|
SystemRequirements: GNU make, C++11
|
||||||
|
|||||||
@ -40,6 +40,7 @@ export(xgb.ggplot.deepness)
|
|||||||
export(xgb.ggplot.importance)
|
export(xgb.ggplot.importance)
|
||||||
export(xgb.importance)
|
export(xgb.importance)
|
||||||
export(xgb.load)
|
export(xgb.load)
|
||||||
|
export(xgb.load.raw)
|
||||||
export(xgb.model.dt.tree)
|
export(xgb.model.dt.tree)
|
||||||
export(xgb.plot.deepness)
|
export(xgb.plot.deepness)
|
||||||
export(xgb.plot.importance)
|
export(xgb.plot.importance)
|
||||||
@ -48,7 +49,9 @@ export(xgb.plot.shap)
|
|||||||
export(xgb.plot.tree)
|
export(xgb.plot.tree)
|
||||||
export(xgb.save)
|
export(xgb.save)
|
||||||
export(xgb.save.raw)
|
export(xgb.save.raw)
|
||||||
|
export(xgb.serialize)
|
||||||
export(xgb.train)
|
export(xgb.train)
|
||||||
|
export(xgb.unserialize)
|
||||||
export(xgboost)
|
export(xgboost)
|
||||||
import(methods)
|
import(methods)
|
||||||
importClassesFrom(Matrix,dgCMatrix)
|
importClassesFrom(Matrix,dgCMatrix)
|
||||||
|
|||||||
@ -5,20 +5,34 @@ xgb.Booster.handle <- function(params = list(), cachelist = list(), modelfile =
|
|||||||
!all(vapply(cachelist, inherits, logical(1), what = 'xgb.DMatrix'))) {
|
!all(vapply(cachelist, inherits, logical(1), what = 'xgb.DMatrix'))) {
|
||||||
stop("cachelist must be a list of xgb.DMatrix objects")
|
stop("cachelist must be a list of xgb.DMatrix objects")
|
||||||
}
|
}
|
||||||
|
## Load existing model, dispatch for on disk model file and in memory buffer
|
||||||
handle <- .Call(XGBoosterCreate_R, cachelist)
|
|
||||||
if (!is.null(modelfile)) {
|
if (!is.null(modelfile)) {
|
||||||
if (typeof(modelfile) == "character") {
|
if (typeof(modelfile) == "character") {
|
||||||
|
## A filename
|
||||||
|
handle <- .Call(XGBoosterCreate_R, cachelist)
|
||||||
.Call(XGBoosterLoadModel_R, handle, modelfile[1])
|
.Call(XGBoosterLoadModel_R, handle, modelfile[1])
|
||||||
|
class(handle) <- "xgb.Booster.handle"
|
||||||
|
if (length(params) > 0) {
|
||||||
|
xgb.parameters(handle) <- params
|
||||||
|
}
|
||||||
|
return(handle)
|
||||||
} else if (typeof(modelfile) == "raw") {
|
} else if (typeof(modelfile) == "raw") {
|
||||||
.Call(XGBoosterLoadModelFromRaw_R, handle, modelfile)
|
## A memory buffer
|
||||||
|
bst <- xgb.unserialize(modelfile)
|
||||||
|
xgb.parameters(bst) <- params
|
||||||
|
return (bst)
|
||||||
} else if (inherits(modelfile, "xgb.Booster")) {
|
} else if (inherits(modelfile, "xgb.Booster")) {
|
||||||
|
## A booster object
|
||||||
bst <- xgb.Booster.complete(modelfile, saveraw = TRUE)
|
bst <- xgb.Booster.complete(modelfile, saveraw = TRUE)
|
||||||
.Call(XGBoosterLoadModelFromRaw_R, handle, bst$raw)
|
bst <- xgb.unserialize(bst$raw)
|
||||||
|
xgb.parameters(bst) <- params
|
||||||
|
return (bst)
|
||||||
} else {
|
} else {
|
||||||
stop("modelfile must be either character filename, or raw booster dump, or xgb.Booster object")
|
stop("modelfile must be either character filename, or raw booster dump, or xgb.Booster object")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
## Create new model
|
||||||
|
handle <- .Call(XGBoosterCreate_R, cachelist)
|
||||||
class(handle) <- "xgb.Booster.handle"
|
class(handle) <- "xgb.Booster.handle"
|
||||||
if (length(params) > 0) {
|
if (length(params) > 0) {
|
||||||
xgb.parameters(handle) <- params
|
xgb.parameters(handle) <- params
|
||||||
@ -113,8 +127,9 @@ xgb.Booster.complete <- function(object, saveraw = TRUE) {
|
|||||||
if (is.null.handle(object$handle)) {
|
if (is.null.handle(object$handle)) {
|
||||||
object$handle <- xgb.Booster.handle(modelfile = object$raw)
|
object$handle <- xgb.Booster.handle(modelfile = object$raw)
|
||||||
} else {
|
} else {
|
||||||
if (is.null(object$raw) && saveraw)
|
if (is.null(object$raw) && saveraw) {
|
||||||
object$raw <- xgb.save.raw(object$handle)
|
object$raw <- xgb.serialize(object$handle)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return(object)
|
return(object)
|
||||||
}
|
}
|
||||||
@ -399,7 +414,7 @@ predict.xgb.Booster.handle <- function(object, ...) {
|
|||||||
#' That would only matter if attributes need to be set many times.
|
#' That would only matter if attributes need to be set many times.
|
||||||
#' Note, however, that when feeding a handle of an \code{xgb.Booster} object to the attribute setters,
|
#' Note, however, that when feeding a handle of an \code{xgb.Booster} object to the attribute setters,
|
||||||
#' the raw model cache of an \code{xgb.Booster} object would not be automatically updated,
|
#' the raw model cache of an \code{xgb.Booster} object would not be automatically updated,
|
||||||
#' and it would be user's responsibility to call \code{xgb.save.raw} to update it.
|
#' and it would be user's responsibility to call \code{xgb.serialize} to update it.
|
||||||
#'
|
#'
|
||||||
#' The \code{xgb.attributes<-} setter either updates the existing or adds one or several attributes,
|
#' The \code{xgb.attributes<-} setter either updates the existing or adds one or several attributes,
|
||||||
#' but it doesn't delete the other existing attributes.
|
#' but it doesn't delete the other existing attributes.
|
||||||
@ -458,7 +473,7 @@ xgb.attr <- function(object, name) {
|
|||||||
}
|
}
|
||||||
.Call(XGBoosterSetAttr_R, handle, as.character(name[1]), value)
|
.Call(XGBoosterSetAttr_R, handle, as.character(name[1]), value)
|
||||||
if (is(object, 'xgb.Booster') && !is.null(object$raw)) {
|
if (is(object, 'xgb.Booster') && !is.null(object$raw)) {
|
||||||
object$raw <- xgb.save.raw(object$handle)
|
object$raw <- xgb.serialize(object$handle)
|
||||||
}
|
}
|
||||||
object
|
object
|
||||||
}
|
}
|
||||||
@ -498,7 +513,7 @@ xgb.attributes <- function(object) {
|
|||||||
.Call(XGBoosterSetAttr_R, handle, names(a[i]), a[[i]])
|
.Call(XGBoosterSetAttr_R, handle, names(a[i]), a[[i]])
|
||||||
}
|
}
|
||||||
if (is(object, 'xgb.Booster') && !is.null(object$raw)) {
|
if (is(object, 'xgb.Booster') && !is.null(object$raw)) {
|
||||||
object$raw <- xgb.save.raw(object$handle)
|
object$raw <- xgb.serialize(object$handle)
|
||||||
}
|
}
|
||||||
object
|
object
|
||||||
}
|
}
|
||||||
@ -528,7 +543,8 @@ xgb.config <- function(object) {
|
|||||||
`xgb.config<-` <- function(object, value) {
|
`xgb.config<-` <- function(object, value) {
|
||||||
handle <- xgb.get.handle(object)
|
handle <- xgb.get.handle(object)
|
||||||
.Call(XGBoosterLoadJsonConfig_R, handle, value)
|
.Call(XGBoosterLoadJsonConfig_R, handle, value)
|
||||||
object$raw <- xgb.Booster.complete(object)
|
object$raw <- NULL # force renew the raw buffer
|
||||||
|
object <- xgb.Booster.complete(object)
|
||||||
object
|
object
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -568,7 +584,7 @@ xgb.config <- function(object) {
|
|||||||
.Call(XGBoosterSetParam_R, handle, names(p[i]), p[[i]])
|
.Call(XGBoosterSetParam_R, handle, names(p[i]), p[[i]])
|
||||||
}
|
}
|
||||||
if (is(object, 'xgb.Booster') && !is.null(object$raw)) {
|
if (is(object, 'xgb.Booster') && !is.null(object$raw)) {
|
||||||
object$raw <- xgb.save.raw(object$handle)
|
object$raw <- xgb.serialize(object$handle)
|
||||||
}
|
}
|
||||||
object
|
object
|
||||||
}
|
}
|
||||||
|
|||||||
14
R-package/R/xgb.load.raw.R
Normal file
14
R-package/R/xgb.load.raw.R
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
#' Load serialised xgboost model from R's raw vector
#'
#' User can generate raw memory buffer by calling xgb.save.raw
#'
#' @param buffer the buffer returned by xgb.save.raw
#'
#' @export
xgb.load.raw <- function(buffer) {
  # Create a fresh booster with an empty cache list, then populate it
  # from the raw model buffer.
  bst_handle <- .Call(XGBoosterCreate_R, list())
  .Call(XGBoosterLoadModelFromRaw_R, bst_handle, buffer)
  # Tag the external pointer so handle methods dispatch on it.
  class(bst_handle) <- "xgb.Booster.handle"
  return(bst_handle)
}
|
||||||
@ -1,23 +1,23 @@
|
|||||||
#' Save xgboost model to R's raw vector,
|
#' Save xgboost model to R's raw vector,
|
||||||
#' user can call xgb.load to load the model back from raw vector
|
#' user can call xgb.load.raw to load the model back from raw vector
|
||||||
#'
|
#'
|
||||||
#' Save xgboost model from xgboost or xgb.train
|
#' Save xgboost model from xgboost or xgb.train
|
||||||
#'
|
#'
|
||||||
#' @param model the model object.
|
#' @param model the model object.
|
||||||
#'
|
#'
|
||||||
#' @examples
|
#' @examples
|
||||||
#' data(agaricus.train, package='xgboost')
|
#' data(agaricus.train, package='xgboost')
|
||||||
#' data(agaricus.test, package='xgboost')
|
#' data(agaricus.test, package='xgboost')
|
||||||
#' train <- agaricus.train
|
#' train <- agaricus.train
|
||||||
#' test <- agaricus.test
|
#' test <- agaricus.test
|
||||||
#' bst <- xgboost(data = train$data, label = train$label, max_depth = 2,
|
#' bst <- xgboost(data = train$data, label = train$label, max_depth = 2,
|
||||||
#' eta = 1, nthread = 2, nrounds = 2,objective = "binary:logistic")
|
#' eta = 1, nthread = 2, nrounds = 2,objective = "binary:logistic")
|
||||||
#' raw <- xgb.save.raw(bst)
|
#' raw <- xgb.save.raw(bst)
|
||||||
#' bst <- xgb.load(raw)
|
#' bst <- xgb.load.raw(raw)
|
||||||
#' pred <- predict(bst, test$data)
|
#' pred <- predict(bst, test$data)
|
||||||
#'
|
#'
|
||||||
#' @export
|
#' @export
|
||||||
xgb.save.raw <- function(model) {
|
xgb.save.raw <- function(model) {
|
||||||
model <- xgb.get.handle(model)
|
handle <- xgb.get.handle(model)
|
||||||
.Call(XGBoosterModelToRaw_R, model)
|
.Call(XGBoosterModelToRaw_R, handle)
|
||||||
}
|
}
|
||||||
|
|||||||
21
R-package/R/xgb.serialize.R
Normal file
21
R-package/R/xgb.serialize.R
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
#' Serialize the booster instance into R's raw vector. The serialization method differs
#' from \code{\link{xgb.save.raw}} as the latter one saves only the model but not
#' parameters. This serialization format is not stable across different xgboost versions.
#'
#' @param booster the booster instance
#'
#' @examples
#' data(agaricus.train, package='xgboost')
#' data(agaricus.test, package='xgboost')
#' train <- agaricus.train
#' test <- agaricus.test
#' bst <- xgboost(data = train$data, label = train$label, max_depth = 2,
#'                eta = 1, nthread = 2, nrounds = 2,objective = "binary:logistic")
#' raw <- xgb.serialize(bst)
#' bst <- xgb.unserialize(raw)
#'
#' @export
xgb.serialize <- function(booster) {
  # Resolve whatever was passed in to its underlying handle and hand that
  # straight to the native serializer; the raw vector it returns is the result.
  .Call(XGBoosterSerializeToBuffer_R, xgb.get.handle(booster))
}
|
||||||
12
R-package/R/xgb.unserialize.R
Normal file
12
R-package/R/xgb.unserialize.R
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
#' Load the instance back from \code{\link{xgb.serialize}}
#'
#' @param buffer the buffer containing booster instance saved by \code{\link{xgb.serialize}}
#'
#' @export
xgb.unserialize <- function(buffer) {
  # Create a fresh booster with an empty cache list, then restore its
  # state from the serialized buffer.
  bst_handle <- .Call(XGBoosterCreate_R, list())
  .Call(XGBoosterUnserializeFromBuffer_R, bst_handle, buffer)
  # Tag the external pointer so handle methods dispatch on it.
  class(bst_handle) <- "xgb.Booster.handle"
  return(bst_handle)
}
|
||||||
@ -4,8 +4,10 @@
|
|||||||
\name{agaricus.test}
|
\name{agaricus.test}
|
||||||
\alias{agaricus.test}
|
\alias{agaricus.test}
|
||||||
\title{Test part from Mushroom Data Set}
|
\title{Test part from Mushroom Data Set}
|
||||||
\format{A list containing a label vector, and a dgCMatrix object with 1611
|
\format{
|
||||||
rows and 126 variables}
|
A list containing a label vector, and a dgCMatrix object with 1611
|
||||||
|
rows and 126 variables
|
||||||
|
}
|
||||||
\usage{
|
\usage{
|
||||||
data(agaricus.test)
|
data(agaricus.test)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -4,8 +4,10 @@
|
|||||||
\name{agaricus.train}
|
\name{agaricus.train}
|
||||||
\alias{agaricus.train}
|
\alias{agaricus.train}
|
||||||
\title{Training part from Mushroom Data Set}
|
\title{Training part from Mushroom Data Set}
|
||||||
\format{A list containing a label vector, and a dgCMatrix object with 6513
|
\format{
|
||||||
rows and 127 variables}
|
A list containing a label vector, and a dgCMatrix object with 6513
|
||||||
|
rows and 127 variables
|
||||||
|
}
|
||||||
\usage{
|
\usage{
|
||||||
data(agaricus.train)
|
data(agaricus.train)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -55,7 +55,7 @@ than for \code{xgb.Booster}, since only just a handle (pointer) would need to be
|
|||||||
That would only matter if attributes need to be set many times.
|
That would only matter if attributes need to be set many times.
|
||||||
Note, however, that when feeding a handle of an \code{xgb.Booster} object to the attribute setters,
|
Note, however, that when feeding a handle of an \code{xgb.Booster} object to the attribute setters,
|
||||||
the raw model cache of an \code{xgb.Booster} object would not be automatically updated,
|
the raw model cache of an \code{xgb.Booster} object would not be automatically updated,
|
||||||
and it would be user's responsibility to call \code{xgb.save.raw} to update it.
|
and it would be user's responsibility to call \code{xgb.serialize} to update it.
|
||||||
|
|
||||||
The \code{xgb.attributes<-} setter either updates the existing or adds one or several attributes,
|
The \code{xgb.attributes<-} setter either updates the existing or adds one or several attributes,
|
||||||
but it doesn't delete the other existing attributes.
|
but it doesn't delete the other existing attributes.
|
||||||
|
|||||||
@ -135,7 +135,7 @@ An object of class \code{xgb.cv.synchronous} with the following elements:
|
|||||||
(only available with early stopping).
|
(only available with early stopping).
|
||||||
\item \code{pred} CV prediction values available when \code{prediction} is set.
|
\item \code{pred} CV prediction values available when \code{prediction} is set.
|
||||||
It is either vector or matrix (see \code{\link{cb.cv.predict}}).
|
It is either vector or matrix (see \code{\link{cb.cv.predict}}).
|
||||||
\item \code{models} a liost of the CV folds' models. It is only available with the explicit
|
\item \code{models} a list of the CV folds' models. It is only available with the explicit
|
||||||
setting of the \code{cb.cv.predict(save_models = TRUE)} callback.
|
setting of the \code{cb.cv.predict(save_models = TRUE)} callback.
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
14
R-package/man/xgb.load.raw.Rd
Normal file
14
R-package/man/xgb.load.raw.Rd
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
% Generated by roxygen2: do not edit by hand
|
||||||
|
% Please edit documentation in R/xgb.load.raw.R
|
||||||
|
\name{xgb.load.raw}
|
||||||
|
\alias{xgb.load.raw}
|
||||||
|
\title{Load serialised xgboost model from R's raw vector}
|
||||||
|
\usage{
|
||||||
|
xgb.load.raw(buffer)
|
||||||
|
}
|
||||||
|
\arguments{
|
||||||
|
\item{buffer}{the buffer returned by xgb.save.raw}
|
||||||
|
}
|
||||||
|
\description{
|
||||||
|
User can generate raw memory buffer by calling xgb.save.raw
|
||||||
|
}
|
||||||
@ -3,7 +3,7 @@
|
|||||||
\name{xgb.save.raw}
|
\name{xgb.save.raw}
|
||||||
\alias{xgb.save.raw}
|
\alias{xgb.save.raw}
|
||||||
\title{Save xgboost model to R's raw vector,
|
\title{Save xgboost model to R's raw vector,
|
||||||
user can call xgb.load to load the model back from raw vector}
|
user can call xgb.load.raw to load the model back from raw vector}
|
||||||
\usage{
|
\usage{
|
||||||
xgb.save.raw(model)
|
xgb.save.raw(model)
|
||||||
}
|
}
|
||||||
@ -18,10 +18,10 @@ data(agaricus.train, package='xgboost')
|
|||||||
data(agaricus.test, package='xgboost')
|
data(agaricus.test, package='xgboost')
|
||||||
train <- agaricus.train
|
train <- agaricus.train
|
||||||
test <- agaricus.test
|
test <- agaricus.test
|
||||||
bst <- xgboost(data = train$data, label = train$label, max_depth = 2,
|
bst <- xgboost(data = train$data, label = train$label, max_depth = 2,
|
||||||
eta = 1, nthread = 2, nrounds = 2,objective = "binary:logistic")
|
eta = 1, nthread = 2, nrounds = 2,objective = "binary:logistic")
|
||||||
raw <- xgb.save.raw(bst)
|
raw <- xgb.save.raw(bst)
|
||||||
bst <- xgb.load(raw)
|
bst <- xgb.load.raw(raw)
|
||||||
pred <- predict(bst, test$data)
|
pred <- predict(bst, test$data)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
29
R-package/man/xgb.serialize.Rd
Normal file
29
R-package/man/xgb.serialize.Rd
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
% Generated by roxygen2: do not edit by hand
|
||||||
|
% Please edit documentation in R/xgb.serialize.R
|
||||||
|
\name{xgb.serialize}
|
||||||
|
\alias{xgb.serialize}
|
||||||
|
\title{Serialize the booster instance into R's raw vector. The serialization method differs
|
||||||
|
from \code{\link{xgb.save.raw}} as the latter one saves only the model but not
|
||||||
|
parameters. This serialization format is not stable across different xgboost versions.}
|
||||||
|
\usage{
|
||||||
|
xgb.serialize(booster)
|
||||||
|
}
|
||||||
|
\arguments{
|
||||||
|
\item{booster}{the booster instance}
|
||||||
|
}
|
||||||
|
\description{
|
||||||
|
Serialize the booster instance into R's raw vector. The serialization method differs
|
||||||
|
from \code{\link{xgb.save.raw}} as the latter one saves only the model but not
|
||||||
|
parameters. This serialization format is not stable across different xgboost versions.
|
||||||
|
}
|
||||||
|
\examples{
|
||||||
|
data(agaricus.train, package='xgboost')
|
||||||
|
data(agaricus.test, package='xgboost')
|
||||||
|
train <- agaricus.train
|
||||||
|
test <- agaricus.test
|
||||||
|
bst <- xgboost(data = train$data, label = train$label, max_depth = 2,
|
||||||
|
eta = 1, nthread = 2, nrounds = 2,objective = "binary:logistic")
|
||||||
|
raw <- xgb.serialize(bst)
|
||||||
|
bst <- xgb.unserialize(raw)
|
||||||
|
|
||||||
|
}
|
||||||
14
R-package/man/xgb.unserialize.Rd
Normal file
14
R-package/man/xgb.unserialize.Rd
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
% Generated by roxygen2: do not edit by hand
|
||||||
|
% Please edit documentation in R/xgb.unserialize.R
|
||||||
|
\name{xgb.unserialize}
|
||||||
|
\alias{xgb.unserialize}
|
||||||
|
\title{Load the instance back from \code{\link{xgb.serialize}}}
|
||||||
|
\usage{
|
||||||
|
xgb.unserialize(buffer)
|
||||||
|
}
|
||||||
|
\arguments{
|
||||||
|
\item{buffer}{the buffer containing booster instance saved by \code{\link{xgb.serialize}}}
|
||||||
|
}
|
||||||
|
\description{
|
||||||
|
Load the instance back from \code{\link{xgb.serialize}}
|
||||||
|
}
|
||||||
@ -25,6 +25,8 @@ extern SEXP XGBoosterLoadModelFromRaw_R(SEXP, SEXP);
|
|||||||
extern SEXP XGBoosterLoadModel_R(SEXP, SEXP);
|
extern SEXP XGBoosterLoadModel_R(SEXP, SEXP);
|
||||||
extern SEXP XGBoosterSaveJsonConfig_R(SEXP handle);
|
extern SEXP XGBoosterSaveJsonConfig_R(SEXP handle);
|
||||||
extern SEXP XGBoosterLoadJsonConfig_R(SEXP handle, SEXP value);
|
extern SEXP XGBoosterLoadJsonConfig_R(SEXP handle, SEXP value);
|
||||||
|
extern SEXP XGBoosterSerializeToBuffer_R(SEXP handle);
|
||||||
|
extern SEXP XGBoosterUnserializeFromBuffer_R(SEXP handle, SEXP raw);
|
||||||
extern SEXP XGBoosterModelToRaw_R(SEXP);
|
extern SEXP XGBoosterModelToRaw_R(SEXP);
|
||||||
extern SEXP XGBoosterPredict_R(SEXP, SEXP, SEXP, SEXP, SEXP);
|
extern SEXP XGBoosterPredict_R(SEXP, SEXP, SEXP, SEXP, SEXP);
|
||||||
extern SEXP XGBoosterSaveModel_R(SEXP, SEXP);
|
extern SEXP XGBoosterSaveModel_R(SEXP, SEXP);
|
||||||
@ -53,6 +55,8 @@ static const R_CallMethodDef CallEntries[] = {
|
|||||||
{"XGBoosterLoadModel_R", (DL_FUNC) &XGBoosterLoadModel_R, 2},
|
{"XGBoosterLoadModel_R", (DL_FUNC) &XGBoosterLoadModel_R, 2},
|
||||||
{"XGBoosterSaveJsonConfig_R", (DL_FUNC) &XGBoosterSaveJsonConfig_R, 1},
|
{"XGBoosterSaveJsonConfig_R", (DL_FUNC) &XGBoosterSaveJsonConfig_R, 1},
|
||||||
{"XGBoosterLoadJsonConfig_R", (DL_FUNC) &XGBoosterLoadJsonConfig_R, 2},
|
{"XGBoosterLoadJsonConfig_R", (DL_FUNC) &XGBoosterLoadJsonConfig_R, 2},
|
||||||
|
{"XGBoosterSerializeToBuffer_R", (DL_FUNC) &XGBoosterSerializeToBuffer_R, 1},
|
||||||
|
{"XGBoosterUnserializeFromBuffer_R", (DL_FUNC) &XGBoosterUnserializeFromBuffer_R, 2},
|
||||||
{"XGBoosterModelToRaw_R", (DL_FUNC) &XGBoosterModelToRaw_R, 1},
|
{"XGBoosterModelToRaw_R", (DL_FUNC) &XGBoosterModelToRaw_R, 1},
|
||||||
{"XGBoosterPredict_R", (DL_FUNC) &XGBoosterPredict_R, 5},
|
{"XGBoosterPredict_R", (DL_FUNC) &XGBoosterPredict_R, 5},
|
||||||
{"XGBoosterSaveModel_R", (DL_FUNC) &XGBoosterSaveModel_R, 2},
|
{"XGBoosterSaveModel_R", (DL_FUNC) &XGBoosterSaveModel_R, 2},
|
||||||
|
|||||||
@ -338,15 +338,6 @@ SEXP XGBoosterSaveModel_R(SEXP handle, SEXP fname) {
|
|||||||
return R_NilValue;
|
return R_NilValue;
|
||||||
}
|
}
|
||||||
|
|
||||||
SEXP XGBoosterLoadModelFromRaw_R(SEXP handle, SEXP raw) {
|
|
||||||
R_API_BEGIN();
|
|
||||||
CHECK_CALL(XGBoosterLoadModelFromBuffer(R_ExternalPtrAddr(handle),
|
|
||||||
RAW(raw),
|
|
||||||
length(raw)));
|
|
||||||
R_API_END();
|
|
||||||
return R_NilValue;
|
|
||||||
}
|
|
||||||
|
|
||||||
SEXP XGBoosterModelToRaw_R(SEXP handle) {
|
SEXP XGBoosterModelToRaw_R(SEXP handle) {
|
||||||
SEXP ret;
|
SEXP ret;
|
||||||
R_API_BEGIN();
|
R_API_BEGIN();
|
||||||
@ -362,6 +353,15 @@ SEXP XGBoosterModelToRaw_R(SEXP handle) {
|
|||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Load a model into the booster behind `handle` from the bytes of the R raw
// vector `raw`. Always returns R_NilValue.
SEXP XGBoosterLoadModelFromRaw_R(SEXP handle, SEXP raw) {
  R_API_BEGIN();
  // The native booster pointer is stored in the R external pointer.
  void *booster = R_ExternalPtrAddr(handle);
  CHECK_CALL(XGBoosterLoadModelFromBuffer(booster, RAW(raw), length(raw)));
  R_API_END();
  return R_NilValue;
}
|
||||||
|
|
||||||
SEXP XGBoosterSaveJsonConfig_R(SEXP handle) {
|
SEXP XGBoosterSaveJsonConfig_R(SEXP handle) {
|
||||||
const char* ret;
|
const char* ret;
|
||||||
R_API_BEGIN();
|
R_API_BEGIN();
|
||||||
@ -380,6 +380,30 @@ SEXP XGBoosterLoadJsonConfig_R(SEXP handle, SEXP value) {
|
|||||||
return R_NilValue;
|
return R_NilValue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Serialize the booster behind `handle` and return the bytes as a newly
// allocated R raw vector.
SEXP XGBoosterSerializeToBuffer_R(SEXP handle) {
  SEXP ret;
  R_API_BEGIN();
  bst_ulong n_bytes;
  const char *snapshot;
  CHECK_CALL(XGBoosterSerializeToBuffer(R_ExternalPtrAddr(handle), &n_bytes, &snapshot));
  // Copy the bytes into a protected R vector of the reported length.
  ret = PROTECT(allocVector(RAWSXP, n_bytes));
  if (n_bytes != 0) {
    // Nothing to copy for an empty result.
    memcpy(RAW(ret), snapshot, n_bytes);
  }
  R_API_END();
  UNPROTECT(1);
  return ret;
}
|
||||||
|
|
||||||
|
// Restore the booster behind `handle` from the bytes of the R raw vector
// `raw` (a buffer produced by XGBoosterSerializeToBuffer_R).
// Always returns R_NilValue.
SEXP XGBoosterUnserializeFromBuffer_R(SEXP handle, SEXP raw) {
  R_API_BEGIN();
  // Wrap the call in CHECK_CALL, as XGBoosterLoadModelFromRaw_R and
  // XGBoosterSerializeToBuffer_R do for their C API calls; without it the
  // status returned by XGBoosterUnserializeFromBuffer is discarded and a
  // failed load goes unnoticed.
  CHECK_CALL(XGBoosterUnserializeFromBuffer(R_ExternalPtrAddr(handle),
                                            RAW(raw),
                                            length(raw)));
  R_API_END();
  return R_NilValue;
}
|
||||||
|
|
||||||
SEXP XGBoosterDumpModel_R(SEXP handle, SEXP fmap, SEXP with_stats, SEXP dump_format) {
|
SEXP XGBoosterDumpModel_R(SEXP handle, SEXP fmap, SEXP with_stats, SEXP dump_format) {
|
||||||
SEXP out;
|
SEXP out;
|
||||||
R_API_BEGIN();
|
R_API_BEGIN();
|
||||||
|
|||||||
@ -187,6 +187,7 @@ XGB_DLL SEXP XGBoosterModelToRaw_R(SEXP handle);
|
|||||||
* \param handle handle
|
* \param handle handle
|
||||||
* \return JSON string
|
* \return JSON string
|
||||||
*/
|
*/
|
||||||
|
|
||||||
XGB_DLL SEXP XGBoosterSaveJsonConfig_R(SEXP handle);
|
XGB_DLL SEXP XGBoosterSaveJsonConfig_R(SEXP handle);
|
||||||
/*!
|
/*!
|
||||||
* \brief Load the JSON string returned by XGBoosterSaveJsonConfig_R
|
* \brief Load the JSON string returned by XGBoosterSaveJsonConfig_R
|
||||||
@ -195,6 +196,22 @@ XGB_DLL SEXP XGBoosterSaveJsonConfig_R(SEXP handle);
|
|||||||
* \return R_NilValue
|
* \return R_NilValue
|
||||||
*/
|
*/
|
||||||
XGB_DLL SEXP XGBoosterLoadJsonConfig_R(SEXP handle, SEXP value);
|
XGB_DLL SEXP XGBoosterLoadJsonConfig_R(SEXP handle, SEXP value);
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \brief Memory snapshot based serialization method. Saves all states
|
||||||
|
* into buffer.
|
||||||
|
* \param handle handle to booster
* \return raw byte array
|
||||||
|
*/
|
||||||
|
XGB_DLL SEXP XGBoosterSerializeToBuffer_R(SEXP handle);
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \brief Memory snapshot based serialization method. Loads the buffer returned
|
||||||
|
* from `XGBoosterSerializeToBuffer'.
|
||||||
|
* \param handle handle to booster
|
||||||
|
* \param raw raw byte array returned from XGBoosterSerializeToBuffer_R
* \return R_NilValue
|
||||||
|
*/
|
||||||
|
XGB_DLL SEXP XGBoosterUnserializeFromBuffer_R(SEXP handle, SEXP raw);
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* \brief dump model into a string
|
* \brief dump model into a string
|
||||||
* \param handle handle
|
* \param handle handle
|
||||||
|
|||||||
@ -219,6 +219,21 @@ test_that("training continuation works", {
|
|||||||
expect_equal(dim(bst2$evaluation_log), c(2, 2))
|
expect_equal(dim(bst2$evaluation_log), c(2, 2))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
test_that("model serialization works", {
  # Round-trip check: serialize -> saveRDS/readRDS -> unserialize -> serialize
  # must reproduce the same raw buffer.
  # The fixture goes to the session temp directory rather than a relative
  # path, so the test never writes into (or collides with files in) the
  # working directory.
  out_path <- tempfile("model_serialization")
  dtrain <- xgb.DMatrix(train$data, label = train$label)
  watchlist <- list(train = dtrain)
  param <- list(objective = "binary:logistic")
  booster <- xgb.train(param, dtrain, nrounds = 4, watchlist)
  raw <- xgb.serialize(booster)
  saveRDS(raw, out_path)
  raw <- readRDS(out_path)

  loaded <- xgb.unserialize(raw)
  raw_from_loaded <- xgb.serialize(loaded)
  expect_equal(raw, raw_from_loaded)
  file.remove(out_path)
})
|
||||||
|
|
||||||
test_that("xgb.cv works", {
|
test_that("xgb.cv works", {
|
||||||
set.seed(11)
|
set.seed(11)
|
||||||
|
|||||||
@ -30,16 +30,16 @@ param <- list(objective = "binary:logistic", max_depth = 2, nthread = 2)
|
|||||||
|
|
||||||
|
|
||||||
test_that("cb.print.evaluation works as expected", {
|
test_that("cb.print.evaluation works as expected", {
|
||||||
|
|
||||||
bst_evaluation <- c('train-auc'=0.9, 'test-auc'=0.8)
|
bst_evaluation <- c('train-auc'=0.9, 'test-auc'=0.8)
|
||||||
bst_evaluation_err <- NULL
|
bst_evaluation_err <- NULL
|
||||||
begin_iteration <- 1
|
begin_iteration <- 1
|
||||||
end_iteration <- 7
|
end_iteration <- 7
|
||||||
|
|
||||||
f0 <- cb.print.evaluation(period=0)
|
f0 <- cb.print.evaluation(period=0)
|
||||||
f1 <- cb.print.evaluation(period=1)
|
f1 <- cb.print.evaluation(period=1)
|
||||||
f5 <- cb.print.evaluation(period=5)
|
f5 <- cb.print.evaluation(period=5)
|
||||||
|
|
||||||
expect_false(is.null(attr(f1, 'call')))
|
expect_false(is.null(attr(f1, 'call')))
|
||||||
expect_equal(attr(f1, 'name'), 'cb.print.evaluation')
|
expect_equal(attr(f1, 'name'), 'cb.print.evaluation')
|
||||||
|
|
||||||
@ -48,15 +48,15 @@ test_that("cb.print.evaluation works as expected", {
|
|||||||
expect_output(f1(), "\\[1\\]\ttrain-auc:0.900000\ttest-auc:0.800000")
|
expect_output(f1(), "\\[1\\]\ttrain-auc:0.900000\ttest-auc:0.800000")
|
||||||
expect_output(f5(), "\\[1\\]\ttrain-auc:0.900000\ttest-auc:0.800000")
|
expect_output(f5(), "\\[1\\]\ttrain-auc:0.900000\ttest-auc:0.800000")
|
||||||
expect_null(f1())
|
expect_null(f1())
|
||||||
|
|
||||||
iteration <- 2
|
iteration <- 2
|
||||||
expect_output(f1(), "\\[2\\]\ttrain-auc:0.900000\ttest-auc:0.800000")
|
expect_output(f1(), "\\[2\\]\ttrain-auc:0.900000\ttest-auc:0.800000")
|
||||||
expect_silent(f5())
|
expect_silent(f5())
|
||||||
|
|
||||||
iteration <- 7
|
iteration <- 7
|
||||||
expect_output(f1(), "\\[7\\]\ttrain-auc:0.900000\ttest-auc:0.800000")
|
expect_output(f1(), "\\[7\\]\ttrain-auc:0.900000\ttest-auc:0.800000")
|
||||||
expect_output(f5(), "\\[7\\]\ttrain-auc:0.900000\ttest-auc:0.800000")
|
expect_output(f5(), "\\[7\\]\ttrain-auc:0.900000\ttest-auc:0.800000")
|
||||||
|
|
||||||
bst_evaluation_err <- c('train-auc'=0.1, 'test-auc'=0.2)
|
bst_evaluation_err <- c('train-auc'=0.1, 'test-auc'=0.2)
|
||||||
expect_output(f1(), "\\[7\\]\ttrain-auc:0.900000\\+0.100000\ttest-auc:0.800000\\+0.200000")
|
expect_output(f1(), "\\[7\\]\ttrain-auc:0.900000\\+0.100000\ttest-auc:0.800000\\+0.200000")
|
||||||
})
|
})
|
||||||
@ -65,40 +65,40 @@ test_that("cb.evaluation.log works as expected", {
|
|||||||
|
|
||||||
bst_evaluation <- c('train-auc'=0.9, 'test-auc'=0.8)
|
bst_evaluation <- c('train-auc'=0.9, 'test-auc'=0.8)
|
||||||
bst_evaluation_err <- NULL
|
bst_evaluation_err <- NULL
|
||||||
|
|
||||||
evaluation_log <- list()
|
evaluation_log <- list()
|
||||||
f <- cb.evaluation.log()
|
f <- cb.evaluation.log()
|
||||||
|
|
||||||
expect_false(is.null(attr(f, 'call')))
|
expect_false(is.null(attr(f, 'call')))
|
||||||
expect_equal(attr(f, 'name'), 'cb.evaluation.log')
|
expect_equal(attr(f, 'name'), 'cb.evaluation.log')
|
||||||
|
|
||||||
iteration <- 1
|
iteration <- 1
|
||||||
expect_silent(f())
|
expect_silent(f())
|
||||||
expect_equal(evaluation_log,
|
expect_equal(evaluation_log,
|
||||||
list(c(iter=1, bst_evaluation)))
|
list(c(iter=1, bst_evaluation)))
|
||||||
iteration <- 2
|
iteration <- 2
|
||||||
expect_silent(f())
|
expect_silent(f())
|
||||||
expect_equal(evaluation_log,
|
expect_equal(evaluation_log,
|
||||||
list(c(iter=1, bst_evaluation), c(iter=2, bst_evaluation)))
|
list(c(iter=1, bst_evaluation), c(iter=2, bst_evaluation)))
|
||||||
expect_silent(f(finalize = TRUE))
|
expect_silent(f(finalize = TRUE))
|
||||||
expect_equal(evaluation_log,
|
expect_equal(evaluation_log,
|
||||||
data.table(iter=1:2, train_auc=c(0.9,0.9), test_auc=c(0.8,0.8)))
|
data.table(iter=1:2, train_auc=c(0.9,0.9), test_auc=c(0.8,0.8)))
|
||||||
|
|
||||||
bst_evaluation_err <- c('train-auc'=0.1, 'test-auc'=0.2)
|
bst_evaluation_err <- c('train-auc'=0.1, 'test-auc'=0.2)
|
||||||
evaluation_log <- list()
|
evaluation_log <- list()
|
||||||
f <- cb.evaluation.log()
|
f <- cb.evaluation.log()
|
||||||
|
|
||||||
iteration <- 1
|
iteration <- 1
|
||||||
expect_silent(f())
|
expect_silent(f())
|
||||||
expect_equal(evaluation_log,
|
expect_equal(evaluation_log,
|
||||||
list(c(iter=1, c(bst_evaluation, bst_evaluation_err))))
|
list(c(iter=1, c(bst_evaluation, bst_evaluation_err))))
|
||||||
iteration <- 2
|
iteration <- 2
|
||||||
expect_silent(f())
|
expect_silent(f())
|
||||||
expect_equal(evaluation_log,
|
expect_equal(evaluation_log,
|
||||||
list(c(iter=1, c(bst_evaluation, bst_evaluation_err)),
|
list(c(iter=1, c(bst_evaluation, bst_evaluation_err)),
|
||||||
c(iter=2, c(bst_evaluation, bst_evaluation_err))))
|
c(iter=2, c(bst_evaluation, bst_evaluation_err))))
|
||||||
expect_silent(f(finalize = TRUE))
|
expect_silent(f(finalize = TRUE))
|
||||||
expect_equal(evaluation_log,
|
expect_equal(evaluation_log,
|
||||||
data.table(iter=1:2,
|
data.table(iter=1:2,
|
||||||
train_auc_mean=c(0.9,0.9), train_auc_std=c(0.1,0.1),
|
train_auc_mean=c(0.9,0.9), train_auc_std=c(0.1,0.1),
|
||||||
test_auc_mean=c(0.8,0.8), test_auc_std=c(0.2,0.2)))
|
test_auc_mean=c(0.8,0.8), test_auc_std=c(0.2,0.2)))
|
||||||
@ -130,18 +130,18 @@ test_that("cb.reset.parameters works as expected", {
|
|||||||
bst1 <- xgb.train(param, dtrain, nrounds = 2, watchlist, verbose = 0,
|
bst1 <- xgb.train(param, dtrain, nrounds = 2, watchlist, verbose = 0,
|
||||||
callbacks = list(cb.reset.parameters(my_par)))
|
callbacks = list(cb.reset.parameters(my_par)))
|
||||||
expect_false(is.null(bst1$evaluation_log$train_error))
|
expect_false(is.null(bst1$evaluation_log$train_error))
|
||||||
expect_equal(bst0$evaluation_log$train_error,
|
expect_equal(bst0$evaluation_log$train_error,
|
||||||
bst1$evaluation_log$train_error)
|
bst1$evaluation_log$train_error)
|
||||||
|
|
||||||
# same eta but re-set via a function in the callback
|
# same eta but re-set via a function in the callback
|
||||||
set.seed(111)
|
set.seed(111)
|
||||||
my_par <- list(eta = function(itr, itr_end) 0.9)
|
my_par <- list(eta = function(itr, itr_end) 0.9)
|
||||||
bst2 <- xgb.train(param, dtrain, nrounds = 2, watchlist, verbose = 0,
|
bst2 <- xgb.train(param, dtrain, nrounds = 2, watchlist, verbose = 0,
|
||||||
callbacks = list(cb.reset.parameters(my_par)))
|
callbacks = list(cb.reset.parameters(my_par)))
|
||||||
expect_false(is.null(bst2$evaluation_log$train_error))
|
expect_false(is.null(bst2$evaluation_log$train_error))
|
||||||
expect_equal(bst0$evaluation_log$train_error,
|
expect_equal(bst0$evaluation_log$train_error,
|
||||||
bst2$evaluation_log$train_error)
|
bst2$evaluation_log$train_error)
|
||||||
|
|
||||||
# different eta re-set as a vector parameter in the callback
|
# different eta re-set as a vector parameter in the callback
|
||||||
set.seed(111)
|
set.seed(111)
|
||||||
my_par <- list(eta = c(0.6, 0.5))
|
my_par <- list(eta = c(0.6, 0.5))
|
||||||
@ -149,7 +149,7 @@ test_that("cb.reset.parameters works as expected", {
|
|||||||
callbacks = list(cb.reset.parameters(my_par)))
|
callbacks = list(cb.reset.parameters(my_par)))
|
||||||
expect_false(is.null(bst3$evaluation_log$train_error))
|
expect_false(is.null(bst3$evaluation_log$train_error))
|
||||||
expect_false(all(bst0$evaluation_log$train_error == bst3$evaluation_log$train_error))
|
expect_false(all(bst0$evaluation_log$train_error == bst3$evaluation_log$train_error))
|
||||||
|
|
||||||
# resetting multiple parameters at the same time runs with no error
|
# resetting multiple parameters at the same time runs with no error
|
||||||
my_par <- list(eta = c(1., 0.5), gamma = c(1, 2), max_depth = c(4, 8))
|
my_par <- list(eta = c(1., 0.5), gamma = c(1, 2), max_depth = c(4, 8))
|
||||||
expect_error(
|
expect_error(
|
||||||
@ -175,7 +175,7 @@ test_that("cb.reset.parameters works as expected", {
|
|||||||
test_that("cb.save.model works as expected", {
|
test_that("cb.save.model works as expected", {
|
||||||
files <- c('xgboost_01.model', 'xgboost_02.model', 'xgboost.model')
|
files <- c('xgboost_01.model', 'xgboost_02.model', 'xgboost.model')
|
||||||
for (f in files) if (file.exists(f)) file.remove(f)
|
for (f in files) if (file.exists(f)) file.remove(f)
|
||||||
|
|
||||||
bst <- xgb.train(param, dtrain, nrounds = 2, watchlist, eta = 1, verbose = 0,
|
bst <- xgb.train(param, dtrain, nrounds = 2, watchlist, eta = 1, verbose = 0,
|
||||||
save_period = 1, save_name = "xgboost_%02d.model")
|
save_period = 1, save_name = "xgboost_%02d.model")
|
||||||
expect_true(file.exists('xgboost_01.model'))
|
expect_true(file.exists('xgboost_01.model'))
|
||||||
@ -184,6 +184,9 @@ test_that("cb.save.model works as expected", {
|
|||||||
expect_equal(xgb.ntree(b1), 1)
|
expect_equal(xgb.ntree(b1), 1)
|
||||||
b2 <- xgb.load('xgboost_02.model')
|
b2 <- xgb.load('xgboost_02.model')
|
||||||
expect_equal(xgb.ntree(b2), 2)
|
expect_equal(xgb.ntree(b2), 2)
|
||||||
|
|
||||||
|
xgb.config(b2) <- xgb.config(bst)
|
||||||
|
expect_equal(xgb.config(bst), xgb.config(b2))
|
||||||
expect_equal(bst$raw, b2$raw)
|
expect_equal(bst$raw, b2$raw)
|
||||||
|
|
||||||
# save_period = 0 saves the last iteration's model
|
# save_period = 0 saves the last iteration's model
|
||||||
@ -191,8 +194,9 @@ test_that("cb.save.model works as expected", {
|
|||||||
save_period = 0)
|
save_period = 0)
|
||||||
expect_true(file.exists('xgboost.model'))
|
expect_true(file.exists('xgboost.model'))
|
||||||
b2 <- xgb.load('xgboost.model')
|
b2 <- xgb.load('xgboost.model')
|
||||||
|
xgb.config(b2) <- xgb.config(bst)
|
||||||
expect_equal(bst$raw, b2$raw)
|
expect_equal(bst$raw, b2$raw)
|
||||||
|
|
||||||
for (f in files) if (file.exists(f)) file.remove(f)
|
for (f in files) if (file.exists(f)) file.remove(f)
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -211,7 +215,7 @@ test_that("early stopping xgb.train works", {
|
|||||||
err_pred <- err(ltest, pred)
|
err_pred <- err(ltest, pred)
|
||||||
err_log <- bst$evaluation_log[bst$best_iteration, test_error]
|
err_log <- bst$evaluation_log[bst$best_iteration, test_error]
|
||||||
expect_equal(err_log, err_pred, tolerance = 5e-6)
|
expect_equal(err_log, err_pred, tolerance = 5e-6)
|
||||||
|
|
||||||
set.seed(11)
|
set.seed(11)
|
||||||
expect_silent(
|
expect_silent(
|
||||||
bst0 <- xgb.train(param, dtrain, nrounds = 20, watchlist, eta = 0.3,
|
bst0 <- xgb.train(param, dtrain, nrounds = 20, watchlist, eta = 0.3,
|
||||||
@ -288,13 +292,13 @@ test_that("prediction in early-stopping xgb.cv works", {
|
|||||||
early_stopping_rounds = 5, maximize = FALSE, stratified = FALSE,
|
early_stopping_rounds = 5, maximize = FALSE, stratified = FALSE,
|
||||||
prediction = TRUE)
|
prediction = TRUE)
|
||||||
, "Stopping. Best iteration")
|
, "Stopping. Best iteration")
|
||||||
|
|
||||||
expect_false(is.null(cv$best_iteration))
|
expect_false(is.null(cv$best_iteration))
|
||||||
expect_lt(cv$best_iteration, 19)
|
expect_lt(cv$best_iteration, 19)
|
||||||
expect_false(is.null(cv$evaluation_log))
|
expect_false(is.null(cv$evaluation_log))
|
||||||
expect_false(is.null(cv$pred))
|
expect_false(is.null(cv$pred))
|
||||||
expect_length(cv$pred, nrow(train$data))
|
expect_length(cv$pred, nrow(train$data))
|
||||||
|
|
||||||
err_pred <- mean( sapply(cv$folds, function(f) mean(err(ltrain[f], cv$pred[f]))) )
|
err_pred <- mean( sapply(cv$folds, function(f) mean(err(ltrain[f], cv$pred[f]))) )
|
||||||
err_log <- cv$evaluation_log[cv$best_iteration, test_error_mean]
|
err_log <- cv$evaluation_log[cv$best_iteration, test_error_mean]
|
||||||
expect_equal(err_pred, err_log, tolerance = 1e-6)
|
expect_equal(err_pred, err_log, tolerance = 1e-6)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user