[R] Construct booster object in load.raw. (#7686)

This commit is contained in:
Jiaming Yuan 2022-02-24 10:06:18 +08:00 committed by GitHub
parent 89aa8ddf52
commit f60d95b0ba
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 16 additions and 14 deletions

View File

@ -3,12 +3,21 @@
#' User can generate raw memory buffer by calling xgb.save.raw #' User can generate raw memory buffer by calling xgb.save.raw
#' #'
#' @param buffer the buffer returned by xgb.save.raw #' @param buffer the buffer returned by xgb.save.raw
#' @param as_booster Return the loaded model as xgb.Booster instead of xgb.Booster.handle.
#' #'
#' @export #' @export
xgb.load.raw <- function(buffer) { xgb.load.raw <- function(buffer, as_booster = FALSE) {
cachelist <- list() cachelist <- list()
handle <- .Call(XGBoosterCreate_R, cachelist) handle <- .Call(XGBoosterCreate_R, cachelist)
.Call(XGBoosterLoadModelFromRaw_R, handle, buffer) .Call(XGBoosterLoadModelFromRaw_R, handle, buffer)
class(handle) <- "xgb.Booster.handle" class(handle) <- "xgb.Booster.handle"
return (handle)
if (as_booster) {
booster <- list(handle = handle, raw = NULL)
class(booster) <- "xgb.Booster"
booster <- xgb.Booster.complete(booster, saveraw = TRUE)
return(booster)
} else {
return (handle)
}
} }

View File

@ -4,10 +4,12 @@
\alias{xgb.load.raw} \alias{xgb.load.raw}
\title{Load serialised xgboost model from R's raw vector} \title{Load serialised xgboost model from R's raw vector}
\usage{ \usage{
xgb.load.raw(buffer) xgb.load.raw(buffer, as_booster = FALSE)
} }
\arguments{ \arguments{
\item{buffer}{the buffer returned by xgb.save.raw} \item{buffer}{the buffer returned by xgb.save.raw}
\item{as_booster}{Return the loaded model as xgb.Booster instead of xgb.Booster.handle.}
} }
\description{ \description{
User can generate raw memory buffer by calling xgb.save.raw User can generate raw memory buffer by calling xgb.save.raw

View File

@ -19,17 +19,8 @@ test_that("load/save raw works", {
ubj_bytes <- xgb.save.raw(booster, raw_format = "ubj") ubj_bytes <- xgb.save.raw(booster, raw_format = "ubj")
old_bytes <- xgb.save.raw(booster, raw_format = "deprecated") old_bytes <- xgb.save.raw(booster, raw_format = "deprecated")
from_json <- xgb.load.raw(json_bytes) from_json <- xgb.load.raw(json_bytes, as_booster = TRUE)
from_ubj <- xgb.load.raw(ubj_bytes) from_ubj <- xgb.load.raw(ubj_bytes, as_booster = TRUE)
## FIXME(jiamingy): Should we include these 3 lines into `xgb.load.raw`?
from_json <- list(handle = from_json, raw = NULL)
class(from_json) <- "xgb.Booster"
from_json <- xgb.Booster.complete(from_json, saveraw = TRUE)
from_ubj <- list(handle = from_ubj, raw = NULL)
class(from_ubj) <- "xgb.Booster"
from_ubj <- xgb.Booster.complete(from_ubj, saveraw = TRUE)
json2old <- xgb.save.raw(from_json, raw_format = "deprecated") json2old <- xgb.save.raw(from_json, raw_format = "deprecated")
ubj2old <- xgb.save.raw(from_ubj, raw_format = "deprecated") ubj2old <- xgb.save.raw(from_ubj, raw_format = "deprecated")