[R] remove default values in internal booster manipulation functions (#9461)
This commit is contained in:
parent
d638535581
commit
428f6cbbe2
@ -511,7 +511,7 @@ cb.cv.predict <- function(save_models = FALSE) {
|
||||
if (save_models) {
|
||||
env$basket$models <- lapply(env$bst_folds, function(fd) {
|
||||
xgb.attr(fd$bst, 'niter') <- env$end_iteration - 1
|
||||
xgb.Booster.complete(xgb.handleToBooster(fd$bst), saveraw = TRUE)
|
||||
xgb.Booster.complete(xgb.handleToBooster(handle = fd$bst, raw = NULL), saveraw = TRUE)
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -659,7 +659,7 @@ cb.gblinear.history <- function(sparse = FALSE) {
|
||||
} else { # xgb.cv:
|
||||
cf <- vector("list", length(env$bst_folds))
|
||||
for (i in seq_along(env$bst_folds)) {
|
||||
dmp <- xgb.dump(xgb.handleToBooster(env$bst_folds[[i]]$bst))
|
||||
dmp <- xgb.dump(xgb.handleToBooster(handle = env$bst_folds[[i]]$bst, raw = NULL))
|
||||
cf[[i]] <- as.numeric(grep('(booster|bias|weigh)', dmp, invert = TRUE, value = TRUE))
|
||||
if (sparse) cf[[i]] <- as(cf[[i]], "sparseVector")
|
||||
}
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
# Construct an internal xgboost Booster and return a handle to it.
|
||||
# internal utility function
|
||||
xgb.Booster.handle <- function(params = list(), cachelist = list(),
|
||||
modelfile = NULL, handle = NULL) {
|
||||
xgb.Booster.handle <- function(params, cachelist, modelfile, handle) {
|
||||
if (typeof(cachelist) != "list" ||
|
||||
!all(vapply(cachelist, inherits, logical(1), what = 'xgb.DMatrix'))) {
|
||||
stop("cachelist must be a list of xgb.DMatrix objects")
|
||||
@ -44,7 +43,7 @@ xgb.Booster.handle <- function(params = list(), cachelist = list(),
|
||||
|
||||
# Convert xgb.Booster.handle to xgb.Booster
|
||||
# internal utility function
|
||||
xgb.handleToBooster <- function(handle, raw = NULL) {
|
||||
xgb.handleToBooster <- function(handle, raw) {
|
||||
bst <- list(handle = handle, raw = raw)
|
||||
class(bst) <- "xgb.Booster"
|
||||
return(bst)
|
||||
@ -129,7 +128,12 @@ xgb.Booster.complete <- function(object, saveraw = TRUE) {
|
||||
stop("argument type must be xgb.Booster")
|
||||
|
||||
if (is.null.handle(object$handle)) {
|
||||
object$handle <- xgb.Booster.handle(modelfile = object$raw, handle = object$handle)
|
||||
object$handle <- xgb.Booster.handle(
|
||||
params = list(),
|
||||
cachelist = list(),
|
||||
modelfile = object$raw,
|
||||
handle = object$handle
|
||||
)
|
||||
} else {
|
||||
if (is.null(object$raw) && saveraw) {
|
||||
object$raw <- xgb.serialize(object$handle)
|
||||
@ -475,7 +479,7 @@ predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FA
|
||||
#' @export
|
||||
predict.xgb.Booster.handle <- function(object, ...) {
|
||||
|
||||
bst <- xgb.handleToBooster(object)
|
||||
bst <- xgb.handleToBooster(handle = object, raw = NULL)
|
||||
|
||||
ret <- predict(bst, ...)
|
||||
return(ret)
|
||||
|
||||
@ -202,7 +202,12 @@ xgb.cv <- function(params = list(), data, nrounds, nfold, label = NULL, missing
|
||||
dtrain <- slice(dall, unlist(folds[-k]))
|
||||
else
|
||||
dtrain <- slice(dall, train_folds[[k]])
|
||||
handle <- xgb.Booster.handle(params, list(dtrain, dtest))
|
||||
handle <- xgb.Booster.handle(
|
||||
params = params,
|
||||
cachelist = list(dtrain, dtest),
|
||||
modelfile = NULL,
|
||||
handle = NULL
|
||||
)
|
||||
list(dtrain = dtrain, bst = handle, watchlist = list(train = dtrain, test = dtest), index = folds[[k]])
|
||||
})
|
||||
rm(dall)
|
||||
|
||||
@ -35,7 +35,12 @@ xgb.load <- function(modelfile) {
|
||||
if (is.null(modelfile))
|
||||
stop("xgb.load: modelfile cannot be NULL")
|
||||
|
||||
handle <- xgb.Booster.handle(modelfile = modelfile)
|
||||
handle <- xgb.Booster.handle(
|
||||
params = list(),
|
||||
cachelist = list(),
|
||||
modelfile = modelfile,
|
||||
handle = NULL
|
||||
)
|
||||
# re-use modelfile if it is raw so we do not need to serialize
|
||||
if (typeof(modelfile) == "raw") {
|
||||
warning(
|
||||
@ -45,9 +50,9 @@ xgb.load <- function(modelfile) {
|
||||
" `xgb.unserialize` instead. "
|
||||
)
|
||||
)
|
||||
bst <- xgb.handleToBooster(handle, modelfile)
|
||||
bst <- xgb.handleToBooster(handle = handle, raw = modelfile)
|
||||
} else {
|
||||
bst <- xgb.handleToBooster(handle, NULL)
|
||||
bst <- xgb.handleToBooster(handle = handle, raw = NULL)
|
||||
}
|
||||
bst <- xgb.Booster.complete(bst, saveraw = TRUE)
|
||||
return(bst)
|
||||
|
||||
@ -363,8 +363,13 @@ xgb.train <- function(params = list(), data, nrounds, watchlist = list(),
|
||||
is_update <- NVL(params[['process_type']], '.') == 'update'
|
||||
|
||||
# Construct a booster (either a new one or load from xgb_model)
|
||||
handle <- xgb.Booster.handle(params, append(watchlist, dtrain), xgb_model)
|
||||
bst <- xgb.handleToBooster(handle)
|
||||
handle <- xgb.Booster.handle(
|
||||
params = params,
|
||||
cachelist = append(watchlist, dtrain),
|
||||
modelfile = xgb_model,
|
||||
handle = NULL
|
||||
)
|
||||
bst <- xgb.handleToBooster(handle = handle, raw = NULL)
|
||||
|
||||
# extract parameters that can affect the relationship b/w #trees and #iterations
|
||||
num_class <- max(as.numeric(NVL(params[['num_class']], 1)), 1)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user