diff --git a/R-package/R/xgb.load.raw.R b/R-package/R/xgb.load.raw.R index 2a7d375a9..d531da6c9 100644 --- a/R-package/R/xgb.load.raw.R +++ b/R-package/R/xgb.load.raw.R @@ -3,12 +3,21 @@ #' User can generate raw memory buffer by calling xgb.save.raw #' #' @param buffer the buffer returned by xgb.save.raw +#' @param as_booster Return the loaded model as xgb.Booster instead of xgb.Booster.handle. #' #' @export -xgb.load.raw <- function(buffer) { +xgb.load.raw <- function(buffer, as_booster = FALSE) { cachelist <- list() handle <- .Call(XGBoosterCreate_R, cachelist) .Call(XGBoosterLoadModelFromRaw_R, handle, buffer) class(handle) <- "xgb.Booster.handle" - return (handle) + + if (as_booster) { + booster <- list(handle = handle, raw = NULL) + class(booster) <- "xgb.Booster" + booster <- xgb.Booster.complete(booster, saveraw = TRUE) + return(booster) + } else { + return (handle) + } } diff --git a/R-package/man/xgb.load.raw.Rd b/R-package/man/xgb.load.raw.Rd index f0248cd9e..0af890e69 100644 --- a/R-package/man/xgb.load.raw.Rd +++ b/R-package/man/xgb.load.raw.Rd @@ -4,10 +4,12 @@ \alias{xgb.load.raw} \title{Load serialised xgboost model from R's raw vector} \usage{ -xgb.load.raw(buffer) +xgb.load.raw(buffer, as_booster = FALSE) } \arguments{ \item{buffer}{the buffer returned by xgb.save.raw} + +\item{as_booster}{Return the loaded model as xgb.Booster instead of xgb.Booster.handle.} } \description{ User can generate raw memory buffer by calling xgb.save.raw diff --git a/R-package/tests/testthat/test_io.R b/R-package/tests/testthat/test_io.R index f4990352f..5b2bc4265 100644 --- a/R-package/tests/testthat/test_io.R +++ b/R-package/tests/testthat/test_io.R @@ -19,17 +19,8 @@ test_that("load/save raw works", { ubj_bytes <- xgb.save.raw(booster, raw_format = "ubj") old_bytes <- xgb.save.raw(booster, raw_format = "deprecated") - from_json <- xgb.load.raw(json_bytes) - from_ubj <- xgb.load.raw(ubj_bytes) - - ## FIXME(jiamingy): Should we include these 3 lines into `xgb.load.raw`? - from_json <- list(handle = from_json, raw = NULL) - class(from_json) <- "xgb.Booster" - from_json <- xgb.Booster.complete(from_json, saveraw = TRUE) - - from_ubj <- list(handle = from_ubj, raw = NULL) - class(from_ubj) <- "xgb.Booster" - from_ubj <- xgb.Booster.complete(from_ubj, saveraw = TRUE) + from_json <- xgb.load.raw(json_bytes, as_booster = TRUE) + from_ubj <- xgb.load.raw(ubj_bytes, as_booster = TRUE) json2old <- xgb.save.raw(from_json, raw_format = "deprecated") ubj2old <- xgb.save.raw(from_ubj, raw_format = "deprecated")