Update R handles in-place (#6903)
* update R handles in-place #fixes 6896 * update test to expect non-null handle * remove unused variable * fix failing tests * solve linter complains
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
# Construct an internal xgboost Booster and return a handle to it.
|
||||
# internal utility function
|
||||
xgb.Booster.handle <- function(params = list(), cachelist = list(),
|
||||
modelfile = NULL) {
|
||||
modelfile = NULL, handle = NULL) {
|
||||
if (typeof(cachelist) != "list" ||
|
||||
!all(vapply(cachelist, inherits, logical(1), what = 'xgb.DMatrix'))) {
|
||||
stop("cachelist must be a list of xgb.DMatrix objects")
|
||||
@@ -20,7 +20,7 @@ xgb.Booster.handle <- function(params = list(), cachelist = list(),
|
||||
return(handle)
|
||||
} else if (typeof(modelfile) == "raw") {
|
||||
## A memory buffer
|
||||
bst <- xgb.unserialize(modelfile)
|
||||
bst <- xgb.unserialize(modelfile, handle)
|
||||
xgb.parameters(bst) <- params
|
||||
return (bst)
|
||||
} else if (inherits(modelfile, "xgb.Booster")) {
|
||||
@@ -129,7 +129,7 @@ xgb.Booster.complete <- function(object, saveraw = TRUE) {
|
||||
stop("argument type must be xgb.Booster")
|
||||
|
||||
if (is.null.handle(object$handle)) {
|
||||
object$handle <- xgb.Booster.handle(modelfile = object$raw)
|
||||
object$handle <- xgb.Booster.handle(modelfile = object$raw, handle = object$handle)
|
||||
} else {
|
||||
if (is.null(object$raw) && saveraw) {
|
||||
object$raw <- xgb.serialize(object$handle)
|
||||
|
||||
@@ -1,11 +1,21 @@
|
||||
#' Load the instance back from \code{\link{xgb.serialize}}
|
||||
#'
|
||||
#' @param buffer the buffer containing booster instance saved by \code{\link{xgb.serialize}}
|
||||
#' @param handle An \code{xgb.Booster.handle} object which will be overwritten with
|
||||
#' the new deserialized object. Must be a null handle (e.g. when loading the model through
|
||||
#' `readRDS`). If not provided, a new handle will be created.
|
||||
#' @return An \code{xgb.Booster.handle} object.
|
||||
#'
|
||||
#' @export
|
||||
xgb.unserialize <- function(buffer) {
|
||||
xgb.unserialize <- function(buffer, handle = NULL) {
|
||||
cachelist <- list()
|
||||
handle <- .Call(XGBoosterCreate_R, cachelist)
|
||||
if (is.null(handle)) {
|
||||
handle <- .Call(XGBoosterCreate_R, cachelist)
|
||||
} else {
|
||||
if (!is.null.handle(handle))
|
||||
stop("'handle' is not null/empty. Cannot overwrite existing handle.")
|
||||
.Call(XGBoosterCreateInEmptyObj_R, cachelist, handle)
|
||||
}
|
||||
tryCatch(
|
||||
.Call(XGBoosterUnserializeFromBuffer_R, handle, buffer),
|
||||
error = function(e) {
|
||||
|
||||
Reference in New Issue
Block a user