Update R handles in-place (#6903)

* update R handles in-place #fixes 6896

* update test to expect non-null handle

* remove unused variable

* fix failing tests

* solve linter complains
This commit is contained in:
david-cortes
2021-04-29 22:50:46 +03:00
committed by GitHub
parent 5472ef626c
commit 4e1a8b1fe5
8 changed files with 51 additions and 7 deletions

View File

@@ -1,7 +1,7 @@
# Construct an internal xgboost Booster and return a handle to it.
# internal utility function
xgb.Booster.handle <- function(params = list(), cachelist = list(),
modelfile = NULL) {
modelfile = NULL, handle = NULL) {
if (typeof(cachelist) != "list" ||
!all(vapply(cachelist, inherits, logical(1), what = 'xgb.DMatrix'))) {
stop("cachelist must be a list of xgb.DMatrix objects")
@@ -20,7 +20,7 @@ xgb.Booster.handle <- function(params = list(), cachelist = list(),
return(handle)
} else if (typeof(modelfile) == "raw") {
## A memory buffer
bst <- xgb.unserialize(modelfile)
bst <- xgb.unserialize(modelfile, handle)
xgb.parameters(bst) <- params
return (bst)
} else if (inherits(modelfile, "xgb.Booster")) {
@@ -129,7 +129,7 @@ xgb.Booster.complete <- function(object, saveraw = TRUE) {
stop("argument type must be xgb.Booster")
if (is.null.handle(object$handle)) {
object$handle <- xgb.Booster.handle(modelfile = object$raw)
object$handle <- xgb.Booster.handle(modelfile = object$raw, handle = object$handle)
} else {
if (is.null(object$raw) && saveraw) {
object$raw <- xgb.serialize(object$handle)

View File

@@ -1,11 +1,21 @@
#' Load the instance back from \code{\link{xgb.serialize}}
#'
#' @param buffer the buffer containing booster instance saved by \code{\link{xgb.serialize}}
#' @param handle An \code{xgb.Booster.handle} object which will be overwritten with
#' the new deserialized object. Must be a null handle (e.g. when loading the model through
#' `readRDS`). If not provided, a new handle will be created.
#' @return An \code{xgb.Booster.handle} object.
#'
#' @export
xgb.unserialize <- function(buffer) {
xgb.unserialize <- function(buffer, handle = NULL) {
cachelist <- list()
handle <- .Call(XGBoosterCreate_R, cachelist)
if (is.null(handle)) {
handle <- .Call(XGBoosterCreate_R, cachelist)
} else {
if (!is.null.handle(handle))
stop("'handle' is not null/empty. Cannot overwrite existing handle.")
.Call(XGBoosterCreateInEmptyObj_R, cachelist, handle)
}
tryCatch(
.Call(XGBoosterUnserializeFromBuffer_R, handle, buffer),
error = function(e) {