[R] remove default values in internal booster manipulation functions (#9461)

This commit is contained in:
James Lamb 2023-08-11 02:07:18 -05:00 committed by GitHub
parent d638535581
commit 428f6cbbe2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 32 additions and 13 deletions

View File

@ -511,7 +511,7 @@ cb.cv.predict <- function(save_models = FALSE) {
if (save_models) { if (save_models) {
env$basket$models <- lapply(env$bst_folds, function(fd) { env$basket$models <- lapply(env$bst_folds, function(fd) {
xgb.attr(fd$bst, 'niter') <- env$end_iteration - 1 xgb.attr(fd$bst, 'niter') <- env$end_iteration - 1
xgb.Booster.complete(xgb.handleToBooster(fd$bst), saveraw = TRUE) xgb.Booster.complete(xgb.handleToBooster(handle = fd$bst, raw = NULL), saveraw = TRUE)
}) })
} }
} }
@ -659,7 +659,7 @@ cb.gblinear.history <- function(sparse = FALSE) {
} else { # xgb.cv: } else { # xgb.cv:
cf <- vector("list", length(env$bst_folds)) cf <- vector("list", length(env$bst_folds))
for (i in seq_along(env$bst_folds)) { for (i in seq_along(env$bst_folds)) {
dmp <- xgb.dump(xgb.handleToBooster(env$bst_folds[[i]]$bst)) dmp <- xgb.dump(xgb.handleToBooster(handle = env$bst_folds[[i]]$bst, raw = NULL))
cf[[i]] <- as.numeric(grep('(booster|bias|weigh)', dmp, invert = TRUE, value = TRUE)) cf[[i]] <- as.numeric(grep('(booster|bias|weigh)', dmp, invert = TRUE, value = TRUE))
if (sparse) cf[[i]] <- as(cf[[i]], "sparseVector") if (sparse) cf[[i]] <- as(cf[[i]], "sparseVector")
} }

View File

@ -1,7 +1,6 @@
# Construct an internal xgboost Booster and return a handle to it. # Construct an internal xgboost Booster and return a handle to it.
# internal utility function # internal utility function
xgb.Booster.handle <- function(params = list(), cachelist = list(), xgb.Booster.handle <- function(params, cachelist, modelfile, handle) {
modelfile = NULL, handle = NULL) {
if (typeof(cachelist) != "list" || if (typeof(cachelist) != "list" ||
!all(vapply(cachelist, inherits, logical(1), what = 'xgb.DMatrix'))) { !all(vapply(cachelist, inherits, logical(1), what = 'xgb.DMatrix'))) {
stop("cachelist must be a list of xgb.DMatrix objects") stop("cachelist must be a list of xgb.DMatrix objects")
@ -44,7 +43,7 @@ xgb.Booster.handle <- function(params = list(), cachelist = list(),
# Convert xgb.Booster.handle to xgb.Booster # Convert xgb.Booster.handle to xgb.Booster
# internal utility function # internal utility function
xgb.handleToBooster <- function(handle, raw = NULL) { xgb.handleToBooster <- function(handle, raw) {
bst <- list(handle = handle, raw = raw) bst <- list(handle = handle, raw = raw)
class(bst) <- "xgb.Booster" class(bst) <- "xgb.Booster"
return(bst) return(bst)
@ -129,7 +128,12 @@ xgb.Booster.complete <- function(object, saveraw = TRUE) {
stop("argument type must be xgb.Booster") stop("argument type must be xgb.Booster")
if (is.null.handle(object$handle)) { if (is.null.handle(object$handle)) {
object$handle <- xgb.Booster.handle(modelfile = object$raw, handle = object$handle) object$handle <- xgb.Booster.handle(
params = list(),
cachelist = list(),
modelfile = object$raw,
handle = object$handle
)
} else { } else {
if (is.null(object$raw) && saveraw) { if (is.null(object$raw) && saveraw) {
object$raw <- xgb.serialize(object$handle) object$raw <- xgb.serialize(object$handle)
@ -475,7 +479,7 @@ predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FA
#' @export #' @export
predict.xgb.Booster.handle <- function(object, ...) { predict.xgb.Booster.handle <- function(object, ...) {
bst <- xgb.handleToBooster(object) bst <- xgb.handleToBooster(handle = object, raw = NULL)
ret <- predict(bst, ...) ret <- predict(bst, ...)
return(ret) return(ret)

View File

@ -202,7 +202,12 @@ xgb.cv <- function(params = list(), data, nrounds, nfold, label = NULL, missing
dtrain <- slice(dall, unlist(folds[-k])) dtrain <- slice(dall, unlist(folds[-k]))
else else
dtrain <- slice(dall, train_folds[[k]]) dtrain <- slice(dall, train_folds[[k]])
handle <- xgb.Booster.handle(params, list(dtrain, dtest)) handle <- xgb.Booster.handle(
params = params,
cachelist = list(dtrain, dtest),
modelfile = NULL,
handle = NULL
)
list(dtrain = dtrain, bst = handle, watchlist = list(train = dtrain, test = dtest), index = folds[[k]]) list(dtrain = dtrain, bst = handle, watchlist = list(train = dtrain, test = dtest), index = folds[[k]])
}) })
rm(dall) rm(dall)

View File

@ -35,7 +35,12 @@ xgb.load <- function(modelfile) {
if (is.null(modelfile)) if (is.null(modelfile))
stop("xgb.load: modelfile cannot be NULL") stop("xgb.load: modelfile cannot be NULL")
handle <- xgb.Booster.handle(modelfile = modelfile) handle <- xgb.Booster.handle(
params = list(),
cachelist = list(),
modelfile = modelfile,
handle = NULL
)
# re-use modelfile if it is raw so we do not need to serialize # re-use modelfile if it is raw so we do not need to serialize
if (typeof(modelfile) == "raw") { if (typeof(modelfile) == "raw") {
warning( warning(
@ -45,9 +50,9 @@ xgb.load <- function(modelfile) {
" `xgb.unserialize` instead. " " `xgb.unserialize` instead. "
) )
) )
bst <- xgb.handleToBooster(handle, modelfile) bst <- xgb.handleToBooster(handle = handle, raw = modelfile)
} else { } else {
bst <- xgb.handleToBooster(handle, NULL) bst <- xgb.handleToBooster(handle = handle, raw = NULL)
} }
bst <- xgb.Booster.complete(bst, saveraw = TRUE) bst <- xgb.Booster.complete(bst, saveraw = TRUE)
return(bst) return(bst)

View File

@ -363,8 +363,13 @@ xgb.train <- function(params = list(), data, nrounds, watchlist = list(),
is_update <- NVL(params[['process_type']], '.') == 'update' is_update <- NVL(params[['process_type']], '.') == 'update'
# Construct a booster (either a new one or load from xgb_model) # Construct a booster (either a new one or load from xgb_model)
handle <- xgb.Booster.handle(params, append(watchlist, dtrain), xgb_model) handle <- xgb.Booster.handle(
bst <- xgb.handleToBooster(handle) params = params,
cachelist = append(watchlist, dtrain),
modelfile = xgb_model,
handle = NULL
)
bst <- xgb.handleToBooster(handle = handle, raw = NULL)
# extract parameters that can affect the relationship b/w #trees and #iterations # extract parameters that can affect the relationship b/w #trees and #iterations
num_class <- max(as.numeric(NVL(params[['num_class']], 1)), 1) num_class <- max(as.numeric(NVL(params[['num_class']], 1)), 1)