[R] remove default values in internal utility functions (#9457)
This commit is contained in:
parent
9dbb71490c
commit
44bd2981b2
@ -140,7 +140,7 @@ check.custom.eval <- function(env = parent.frame()) {
|
|||||||
|
|
||||||
|
|
||||||
# Update a booster handle for an iteration with dtrain data
|
# Update a booster handle for an iteration with dtrain data
|
||||||
xgb.iter.update <- function(booster_handle, dtrain, iter, obj = NULL) {
|
xgb.iter.update <- function(booster_handle, dtrain, iter, obj) {
|
||||||
if (!identical(class(booster_handle), "xgb.Booster.handle")) {
|
if (!identical(class(booster_handle), "xgb.Booster.handle")) {
|
||||||
stop("booster_handle must be of xgb.Booster.handle class")
|
stop("booster_handle must be of xgb.Booster.handle class")
|
||||||
}
|
}
|
||||||
@ -163,7 +163,7 @@ xgb.iter.update <- function(booster_handle, dtrain, iter, obj = NULL) {
|
|||||||
# Evaluate one iteration.
|
# Evaluate one iteration.
|
||||||
# Returns a named vector of evaluation metrics
|
# Returns a named vector of evaluation metrics
|
||||||
# with the names in a 'datasetname-metricname' format.
|
# with the names in a 'datasetname-metricname' format.
|
||||||
xgb.iter.eval <- function(booster_handle, watchlist, iter, feval = NULL) {
|
xgb.iter.eval <- function(booster_handle, watchlist, iter, feval) {
|
||||||
if (!identical(class(booster_handle), "xgb.Booster.handle"))
|
if (!identical(class(booster_handle), "xgb.Booster.handle"))
|
||||||
stop("class of booster_handle must be xgb.Booster.handle")
|
stop("class of booster_handle must be xgb.Booster.handle")
|
||||||
|
|
||||||
@ -234,7 +234,7 @@ generate.cv.folds <- function(nfold, nrows, stratified, label, params) {
|
|||||||
y <- factor(y)
|
y <- factor(y)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
folds <- xgb.createFolds(y, nfold)
|
folds <- xgb.createFolds(y = y, k = nfold)
|
||||||
} else {
|
} else {
|
||||||
# make simple non-stratified folds
|
# make simple non-stratified folds
|
||||||
kstep <- length(rnd_idx) %/% nfold
|
kstep <- length(rnd_idx) %/% nfold
|
||||||
@ -251,7 +251,7 @@ generate.cv.folds <- function(nfold, nrows, stratified, label, params) {
|
|||||||
# Creates CV folds stratified by the values of y.
|
# Creates CV folds stratified by the values of y.
|
||||||
# It was borrowed from caret::createFolds and simplified
|
# It was borrowed from caret::createFolds and simplified
|
||||||
# by always returning an unnamed list of fold indices.
|
# by always returning an unnamed list of fold indices.
|
||||||
xgb.createFolds <- function(y, k = 10) {
|
xgb.createFolds <- function(y, k) {
|
||||||
if (is.numeric(y)) {
|
if (is.numeric(y)) {
|
||||||
## Group the numeric data based on their magnitudes
|
## Group the numeric data based on their magnitudes
|
||||||
## and sample within those groups.
|
## and sample within those groups.
|
||||||
|
|||||||
@ -223,8 +223,18 @@ xgb.cv <- function(params = list(), data, nrounds, nfold, label = NULL, missing
|
|||||||
for (f in cb$pre_iter) f()
|
for (f in cb$pre_iter) f()
|
||||||
|
|
||||||
msg <- lapply(bst_folds, function(fd) {
|
msg <- lapply(bst_folds, function(fd) {
|
||||||
xgb.iter.update(fd$bst, fd$dtrain, iteration - 1, obj)
|
xgb.iter.update(
|
||||||
xgb.iter.eval(fd$bst, fd$watchlist, iteration - 1, feval)
|
booster_handle = fd$bst,
|
||||||
|
dtrain = fd$dtrain,
|
||||||
|
iter = iteration - 1,
|
||||||
|
obj = obj
|
||||||
|
)
|
||||||
|
xgb.iter.eval(
|
||||||
|
booster_handle = fd$bst,
|
||||||
|
watchlist = fd$watchlist,
|
||||||
|
iter = iteration - 1,
|
||||||
|
feval = feval
|
||||||
|
)
|
||||||
})
|
})
|
||||||
msg <- simplify2array(msg)
|
msg <- simplify2array(msg)
|
||||||
bst_evaluation <- rowMeans(msg)
|
bst_evaluation <- rowMeans(msg)
|
||||||
|
|||||||
@ -390,10 +390,21 @@ xgb.train <- function(params = list(), data, nrounds, watchlist = list(),
|
|||||||
|
|
||||||
for (f in cb$pre_iter) f()
|
for (f in cb$pre_iter) f()
|
||||||
|
|
||||||
xgb.iter.update(bst$handle, dtrain, iteration - 1, obj)
|
xgb.iter.update(
|
||||||
|
booster_handle = bst$handle,
|
||||||
|
dtrain = dtrain,
|
||||||
|
iter = iteration - 1,
|
||||||
|
obj = obj
|
||||||
|
)
|
||||||
|
|
||||||
if (length(watchlist) > 0)
|
if (length(watchlist) > 0) {
|
||||||
bst_evaluation <- xgb.iter.eval(bst$handle, watchlist, iteration - 1, feval) # nolint: object_usage_linter
|
bst_evaluation <- xgb.iter.eval( # nolint: object_usage_linter
|
||||||
|
booster_handle = bst$handle,
|
||||||
|
watchlist = watchlist,
|
||||||
|
iter = iteration - 1,
|
||||||
|
feval = feval
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
xgb.attr(bst$handle, 'niter') <- iteration - 1
|
xgb.attr(bst$handle, 'niter') <- iteration - 1
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user