[R] remove default values in internal booster manipulation functions (#9461)
This commit is contained in:
parent
d638535581
commit
428f6cbbe2
@ -511,7 +511,7 @@ cb.cv.predict <- function(save_models = FALSE) {
|
|||||||
if (save_models) {
|
if (save_models) {
|
||||||
env$basket$models <- lapply(env$bst_folds, function(fd) {
|
env$basket$models <- lapply(env$bst_folds, function(fd) {
|
||||||
xgb.attr(fd$bst, 'niter') <- env$end_iteration - 1
|
xgb.attr(fd$bst, 'niter') <- env$end_iteration - 1
|
||||||
xgb.Booster.complete(xgb.handleToBooster(fd$bst), saveraw = TRUE)
|
xgb.Booster.complete(xgb.handleToBooster(handle = fd$bst, raw = NULL), saveraw = TRUE)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -659,7 +659,7 @@ cb.gblinear.history <- function(sparse = FALSE) {
|
|||||||
} else { # xgb.cv:
|
} else { # xgb.cv:
|
||||||
cf <- vector("list", length(env$bst_folds))
|
cf <- vector("list", length(env$bst_folds))
|
||||||
for (i in seq_along(env$bst_folds)) {
|
for (i in seq_along(env$bst_folds)) {
|
||||||
dmp <- xgb.dump(xgb.handleToBooster(env$bst_folds[[i]]$bst))
|
dmp <- xgb.dump(xgb.handleToBooster(handle = env$bst_folds[[i]]$bst, raw = NULL))
|
||||||
cf[[i]] <- as.numeric(grep('(booster|bias|weigh)', dmp, invert = TRUE, value = TRUE))
|
cf[[i]] <- as.numeric(grep('(booster|bias|weigh)', dmp, invert = TRUE, value = TRUE))
|
||||||
if (sparse) cf[[i]] <- as(cf[[i]], "sparseVector")
|
if (sparse) cf[[i]] <- as(cf[[i]], "sparseVector")
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,7 +1,6 @@
|
|||||||
# Construct an internal xgboost Booster and return a handle to it.
|
# Construct an internal xgboost Booster and return a handle to it.
|
||||||
# internal utility function
|
# internal utility function
|
||||||
xgb.Booster.handle <- function(params = list(), cachelist = list(),
|
xgb.Booster.handle <- function(params, cachelist, modelfile, handle) {
|
||||||
modelfile = NULL, handle = NULL) {
|
|
||||||
if (typeof(cachelist) != "list" ||
|
if (typeof(cachelist) != "list" ||
|
||||||
!all(vapply(cachelist, inherits, logical(1), what = 'xgb.DMatrix'))) {
|
!all(vapply(cachelist, inherits, logical(1), what = 'xgb.DMatrix'))) {
|
||||||
stop("cachelist must be a list of xgb.DMatrix objects")
|
stop("cachelist must be a list of xgb.DMatrix objects")
|
||||||
@ -44,7 +43,7 @@ xgb.Booster.handle <- function(params = list(), cachelist = list(),
|
|||||||
|
|
||||||
# Convert xgb.Booster.handle to xgb.Booster
|
# Convert xgb.Booster.handle to xgb.Booster
|
||||||
# internal utility function
|
# internal utility function
|
||||||
xgb.handleToBooster <- function(handle, raw = NULL) {
|
xgb.handleToBooster <- function(handle, raw) {
|
||||||
bst <- list(handle = handle, raw = raw)
|
bst <- list(handle = handle, raw = raw)
|
||||||
class(bst) <- "xgb.Booster"
|
class(bst) <- "xgb.Booster"
|
||||||
return(bst)
|
return(bst)
|
||||||
@ -129,7 +128,12 @@ xgb.Booster.complete <- function(object, saveraw = TRUE) {
|
|||||||
stop("argument type must be xgb.Booster")
|
stop("argument type must be xgb.Booster")
|
||||||
|
|
||||||
if (is.null.handle(object$handle)) {
|
if (is.null.handle(object$handle)) {
|
||||||
object$handle <- xgb.Booster.handle(modelfile = object$raw, handle = object$handle)
|
object$handle <- xgb.Booster.handle(
|
||||||
|
params = list(),
|
||||||
|
cachelist = list(),
|
||||||
|
modelfile = object$raw,
|
||||||
|
handle = object$handle
|
||||||
|
)
|
||||||
} else {
|
} else {
|
||||||
if (is.null(object$raw) && saveraw) {
|
if (is.null(object$raw) && saveraw) {
|
||||||
object$raw <- xgb.serialize(object$handle)
|
object$raw <- xgb.serialize(object$handle)
|
||||||
@ -475,7 +479,7 @@ predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FA
|
|||||||
#' @export
|
#' @export
|
||||||
predict.xgb.Booster.handle <- function(object, ...) {
|
predict.xgb.Booster.handle <- function(object, ...) {
|
||||||
|
|
||||||
bst <- xgb.handleToBooster(object)
|
bst <- xgb.handleToBooster(handle = object, raw = NULL)
|
||||||
|
|
||||||
ret <- predict(bst, ...)
|
ret <- predict(bst, ...)
|
||||||
return(ret)
|
return(ret)
|
||||||
|
|||||||
@ -202,7 +202,12 @@ xgb.cv <- function(params = list(), data, nrounds, nfold, label = NULL, missing
|
|||||||
dtrain <- slice(dall, unlist(folds[-k]))
|
dtrain <- slice(dall, unlist(folds[-k]))
|
||||||
else
|
else
|
||||||
dtrain <- slice(dall, train_folds[[k]])
|
dtrain <- slice(dall, train_folds[[k]])
|
||||||
handle <- xgb.Booster.handle(params, list(dtrain, dtest))
|
handle <- xgb.Booster.handle(
|
||||||
|
params = params,
|
||||||
|
cachelist = list(dtrain, dtest),
|
||||||
|
modelfile = NULL,
|
||||||
|
handle = NULL
|
||||||
|
)
|
||||||
list(dtrain = dtrain, bst = handle, watchlist = list(train = dtrain, test = dtest), index = folds[[k]])
|
list(dtrain = dtrain, bst = handle, watchlist = list(train = dtrain, test = dtest), index = folds[[k]])
|
||||||
})
|
})
|
||||||
rm(dall)
|
rm(dall)
|
||||||
|
|||||||
@ -35,7 +35,12 @@ xgb.load <- function(modelfile) {
|
|||||||
if (is.null(modelfile))
|
if (is.null(modelfile))
|
||||||
stop("xgb.load: modelfile cannot be NULL")
|
stop("xgb.load: modelfile cannot be NULL")
|
||||||
|
|
||||||
handle <- xgb.Booster.handle(modelfile = modelfile)
|
handle <- xgb.Booster.handle(
|
||||||
|
params = list(),
|
||||||
|
cachelist = list(),
|
||||||
|
modelfile = modelfile,
|
||||||
|
handle = NULL
|
||||||
|
)
|
||||||
# re-use modelfile if it is raw so we do not need to serialize
|
# re-use modelfile if it is raw so we do not need to serialize
|
||||||
if (typeof(modelfile) == "raw") {
|
if (typeof(modelfile) == "raw") {
|
||||||
warning(
|
warning(
|
||||||
@ -45,9 +50,9 @@ xgb.load <- function(modelfile) {
|
|||||||
" `xgb.unserialize` instead. "
|
" `xgb.unserialize` instead. "
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
bst <- xgb.handleToBooster(handle, modelfile)
|
bst <- xgb.handleToBooster(handle = handle, raw = modelfile)
|
||||||
} else {
|
} else {
|
||||||
bst <- xgb.handleToBooster(handle, NULL)
|
bst <- xgb.handleToBooster(handle = handle, raw = NULL)
|
||||||
}
|
}
|
||||||
bst <- xgb.Booster.complete(bst, saveraw = TRUE)
|
bst <- xgb.Booster.complete(bst, saveraw = TRUE)
|
||||||
return(bst)
|
return(bst)
|
||||||
|
|||||||
@ -363,8 +363,13 @@ xgb.train <- function(params = list(), data, nrounds, watchlist = list(),
|
|||||||
is_update <- NVL(params[['process_type']], '.') == 'update'
|
is_update <- NVL(params[['process_type']], '.') == 'update'
|
||||||
|
|
||||||
# Construct a booster (either a new one or load from xgb_model)
|
# Construct a booster (either a new one or load from xgb_model)
|
||||||
handle <- xgb.Booster.handle(params, append(watchlist, dtrain), xgb_model)
|
handle <- xgb.Booster.handle(
|
||||||
bst <- xgb.handleToBooster(handle)
|
params = params,
|
||||||
|
cachelist = append(watchlist, dtrain),
|
||||||
|
modelfile = xgb_model,
|
||||||
|
handle = NULL
|
||||||
|
)
|
||||||
|
bst <- xgb.handleToBooster(handle = handle, raw = NULL)
|
||||||
|
|
||||||
# extract parameters that can affect the relationship b/w #trees and #iterations
|
# extract parameters that can affect the relationship b/w #trees and #iterations
|
||||||
num_class <- max(as.numeric(NVL(params[['num_class']], 1)), 1)
|
num_class <- max(as.numeric(NVL(params[['num_class']], 1)), 1)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user