Update R handles in-place (#6903)

* update R handles in-place #fixes 6896

* update test to expect non-null handle

* remove unused variable

* fix failing tests

* resolve linter complaints
This commit is contained in:
david-cortes 2021-04-29 22:50:46 +03:00 committed by GitHub
parent 5472ef626c
commit 4e1a8b1fe5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 51 additions and 7 deletions

View File

@ -1,7 +1,7 @@
# Construct an internal xgboost Booster and return a handle to it. # Construct an internal xgboost Booster and return a handle to it.
# internal utility function # internal utility function
xgb.Booster.handle <- function(params = list(), cachelist = list(), xgb.Booster.handle <- function(params = list(), cachelist = list(),
modelfile = NULL) { modelfile = NULL, handle = NULL) {
if (typeof(cachelist) != "list" || if (typeof(cachelist) != "list" ||
!all(vapply(cachelist, inherits, logical(1), what = 'xgb.DMatrix'))) { !all(vapply(cachelist, inherits, logical(1), what = 'xgb.DMatrix'))) {
stop("cachelist must be a list of xgb.DMatrix objects") stop("cachelist must be a list of xgb.DMatrix objects")
@ -20,7 +20,7 @@ xgb.Booster.handle <- function(params = list(), cachelist = list(),
return(handle) return(handle)
} else if (typeof(modelfile) == "raw") { } else if (typeof(modelfile) == "raw") {
## A memory buffer ## A memory buffer
bst <- xgb.unserialize(modelfile) bst <- xgb.unserialize(modelfile, handle)
xgb.parameters(bst) <- params xgb.parameters(bst) <- params
return (bst) return (bst)
} else if (inherits(modelfile, "xgb.Booster")) { } else if (inherits(modelfile, "xgb.Booster")) {
@ -129,7 +129,7 @@ xgb.Booster.complete <- function(object, saveraw = TRUE) {
stop("argument type must be xgb.Booster") stop("argument type must be xgb.Booster")
if (is.null.handle(object$handle)) { if (is.null.handle(object$handle)) {
object$handle <- xgb.Booster.handle(modelfile = object$raw) object$handle <- xgb.Booster.handle(modelfile = object$raw, handle = object$handle)
} else { } else {
if (is.null(object$raw) && saveraw) { if (is.null(object$raw) && saveraw) {
object$raw <- xgb.serialize(object$handle) object$raw <- xgb.serialize(object$handle)

View File

@ -1,11 +1,21 @@
#' Load the instance back from \code{\link{xgb.serialize}} #' Load the instance back from \code{\link{xgb.serialize}}
#' #'
#' @param buffer the buffer containing booster instance saved by \code{\link{xgb.serialize}} #' @param buffer the buffer containing booster instance saved by \code{\link{xgb.serialize}}
#' @param handle An \code{xgb.Booster.handle} object which will be overwritten with
#' the new deserialized object. Must be a null handle (e.g. when loading the model through
#' \code{readRDS}). If not provided, a new handle will be created.
#' @return An \code{xgb.Booster.handle} object.
#' #'
#' @export #' @export
xgb.unserialize <- function(buffer) { xgb.unserialize <- function(buffer, handle = NULL) {
cachelist <- list() cachelist <- list()
handle <- .Call(XGBoosterCreate_R, cachelist) if (is.null(handle)) {
handle <- .Call(XGBoosterCreate_R, cachelist)
} else {
if (!is.null.handle(handle))
stop("'handle' is not null/empty. Cannot overwrite existing handle.")
.Call(XGBoosterCreateInEmptyObj_R, cachelist, handle)
}
tryCatch( tryCatch(
.Call(XGBoosterUnserializeFromBuffer_R, handle, buffer), .Call(XGBoosterUnserializeFromBuffer_R, handle, buffer),
error = function(e) { error = function(e) {

View File

@ -4,10 +4,17 @@
\alias{xgb.unserialize} \alias{xgb.unserialize}
\title{Load the instance back from \code{\link{xgb.serialize}}} \title{Load the instance back from \code{\link{xgb.serialize}}}
\usage{ \usage{
xgb.unserialize(buffer) xgb.unserialize(buffer, handle = NULL)
} }
\arguments{ \arguments{
\item{buffer}{the buffer containing booster instance saved by \code{\link{xgb.serialize}}} \item{buffer}{the buffer containing booster instance saved by \code{\link{xgb.serialize}}}
\item{handle}{An \code{xgb.Booster.handle} object which will be overwritten with
the new deserialized object. Must be a null handle (e.g. when loading the model through
\code{readRDS}). If not provided, a new handle will be created.}
}
\value{
An \code{xgb.Booster.handle} object.
} }
\description{ \description{
Load the instance back from \code{\link{xgb.serialize}} Load the instance back from \code{\link{xgb.serialize}}

View File

@ -17,6 +17,7 @@ Check these declarations against the C/Fortran source code.
/* .Call calls */ /* .Call calls */
extern SEXP XGBoosterBoostOneIter_R(SEXP, SEXP, SEXP, SEXP); extern SEXP XGBoosterBoostOneIter_R(SEXP, SEXP, SEXP, SEXP);
extern SEXP XGBoosterCreate_R(SEXP); extern SEXP XGBoosterCreate_R(SEXP);
extern SEXP XGBoosterCreateInEmptyObj_R(SEXP, SEXP);
extern SEXP XGBoosterDumpModel_R(SEXP, SEXP, SEXP, SEXP); extern SEXP XGBoosterDumpModel_R(SEXP, SEXP, SEXP, SEXP);
extern SEXP XGBoosterEvalOneIter_R(SEXP, SEXP, SEXP, SEXP); extern SEXP XGBoosterEvalOneIter_R(SEXP, SEXP, SEXP, SEXP);
extern SEXP XGBoosterGetAttrNames_R(SEXP); extern SEXP XGBoosterGetAttrNames_R(SEXP);
@ -49,6 +50,7 @@ extern SEXP XGBGetGlobalConfig_R();
static const R_CallMethodDef CallEntries[] = { static const R_CallMethodDef CallEntries[] = {
{"XGBoosterBoostOneIter_R", (DL_FUNC) &XGBoosterBoostOneIter_R, 4}, {"XGBoosterBoostOneIter_R", (DL_FUNC) &XGBoosterBoostOneIter_R, 4},
{"XGBoosterCreate_R", (DL_FUNC) &XGBoosterCreate_R, 1}, {"XGBoosterCreate_R", (DL_FUNC) &XGBoosterCreate_R, 1},
{"XGBoosterCreateInEmptyObj_R", (DL_FUNC) &XGBoosterCreateInEmptyObj_R, 2},
{"XGBoosterDumpModel_R", (DL_FUNC) &XGBoosterDumpModel_R, 4}, {"XGBoosterDumpModel_R", (DL_FUNC) &XGBoosterDumpModel_R, 4},
{"XGBoosterEvalOneIter_R", (DL_FUNC) &XGBoosterEvalOneIter_R, 4}, {"XGBoosterEvalOneIter_R", (DL_FUNC) &XGBoosterEvalOneIter_R, 4},
{"XGBoosterGetAttrNames_R", (DL_FUNC) &XGBoosterGetAttrNames_R, 1}, {"XGBoosterGetAttrNames_R", (DL_FUNC) &XGBoosterGetAttrNames_R, 1},

View File

@ -272,6 +272,21 @@ SEXP XGBoosterCreate_R(SEXP dmats) {
return ret; return ret;
} }
// Creates a new booster and stores its pointer inside the caller-supplied
// external pointer `R_handle` instead of allocating a fresh R object.
// Always returns R_NilValue; the result is delivered through `R_handle`.
SEXP XGBoosterCreateInEmptyObj_R(SEXP dmats, SEXP R_handle) {
  R_API_BEGIN();
  // Gather the raw pointers held by each element of the R list `dmats`.
  std::vector<void*> cached_dmats;
  const int n_dmats = length(dmats);
  int idx = 0;
  while (idx < n_dmats) {
    void *dmat_ptr = R_ExternalPtrAddr(VECTOR_ELT(dmats, idx));
    cached_dmats.push_back(dmat_ptr);
    ++idx;
  }
  BoosterHandle booster;
  CHECK_CALL(XGBoosterCreate(BeginPtr(cached_dmats), cached_dmats.size(), &booster));
  // Point the existing external pointer at the new booster, then attach the
  // booster finalizer to that same R object.
  R_SetExternalPtrAddr(R_handle, booster);
  R_RegisterCFinalizerEx(R_handle, _BoosterFinalizer, TRUE);
  R_API_END();
  return R_NilValue;
}
SEXP XGBoosterSetParam_R(SEXP handle, SEXP name, SEXP val) { SEXP XGBoosterSetParam_R(SEXP handle, SEXP name, SEXP val) {
R_API_BEGIN(); R_API_BEGIN();
CHECK_CALL(XGBoosterSetParam(R_ExternalPtrAddr(handle), CHECK_CALL(XGBoosterSetParam(R_ExternalPtrAddr(handle),

View File

@ -116,6 +116,14 @@ XGB_DLL SEXP XGDMatrixNumCol_R(SEXP handle);
*/ */
XGB_DLL SEXP XGBoosterCreate_R(SEXP dmats); XGB_DLL SEXP XGBoosterCreate_R(SEXP dmats);
/*!
* \brief create xgboost learner, saving the pointer into an existing R object
*
* The created booster is stored in R_handle (via R_SetExternalPtrAddr) and the
* booster finalizer is registered on R_handle, so no new R object is returned.
* \param dmats a list of dmatrix handles that will be cached
* \param R_handle a clean R external pointer (not holding any object);
*        it is modified in place to point to the newly created booster
* \return R_NilValue; the result is delivered through R_handle
*/
XGB_DLL SEXP XGBoosterCreateInEmptyObj_R(SEXP dmats, SEXP R_handle);
/*! /*!
* \brief set parameters * \brief set parameters
* \param handle handle * \param handle handle

View File

@ -238,12 +238,13 @@ if (grepl('Windows', Sys.info()[['sysname']]) ||
test_that("xgb.Booster serializing as R object works", { test_that("xgb.Booster serializing as R object works", {
saveRDS(bst.Tree, 'xgb.model.rds') saveRDS(bst.Tree, 'xgb.model.rds')
bst <- readRDS('xgb.model.rds') bst <- readRDS('xgb.model.rds')
if (file.exists('xgb.model.rds')) file.remove('xgb.model.rds')
dtrain <- xgb.DMatrix(sparse_matrix, label = label) dtrain <- xgb.DMatrix(sparse_matrix, label = label)
expect_equal(predict(bst.Tree, dtrain), predict(bst, dtrain), tolerance = float_tolerance) expect_equal(predict(bst.Tree, dtrain), predict(bst, dtrain), tolerance = float_tolerance)
expect_equal(xgb.dump(bst.Tree), xgb.dump(bst)) expect_equal(xgb.dump(bst.Tree), xgb.dump(bst))
xgb.save(bst, 'xgb.model') xgb.save(bst, 'xgb.model')
if (file.exists('xgb.model')) file.remove('xgb.model') if (file.exists('xgb.model')) file.remove('xgb.model')
bst <- readRDS('xgb.model.rds')
if (file.exists('xgb.model.rds')) file.remove('xgb.model.rds')
nil_ptr <- new("externalptr") nil_ptr <- new("externalptr")
class(nil_ptr) <- "xgb.Booster.handle" class(nil_ptr) <- "xgb.Booster.handle"
expect_true(identical(bst$handle, nil_ptr)) expect_true(identical(bst$handle, nil_ptr))

View File

@ -83,6 +83,7 @@ test_that("Models from previous versions of XGBoost can be loaded", {
if (is_rds && compareVersion(model_xgb_ver, '1.1.1.1') < 0) { if (is_rds && compareVersion(model_xgb_ver, '1.1.1.1') < 0) {
booster <- readRDS(model_file) booster <- readRDS(model_file)
expect_warning(predict(booster, newdata = pred_data)) expect_warning(predict(booster, newdata = pred_data))
booster <- readRDS(model_file)
expect_warning(run_booster_check(booster, name)) expect_warning(run_booster_check(booster, name))
} else { } else {
if (is_rds) { if (is_rds) {