[R] add parameter deprecation related utilities; code style
This commit is contained in:
parent
76650c096f
commit
c342614a81
@ -27,7 +27,7 @@ NVL <- function(x, val) {
|
|||||||
|
|
||||||
# Merges booster params with whatever is provided in ...
|
# Merges booster params with whatever is provided in ...
|
||||||
# plus runs some checks
|
# plus runs some checks
|
||||||
check.params <- function(params, ...) {
|
check.booster.params <- function(params, ...) {
|
||||||
if (typeof(params) != "list")
|
if (typeof(params) != "list")
|
||||||
stop("params must be a list")
|
stop("params must be a list")
|
||||||
|
|
||||||
@ -35,25 +35,26 @@ check.params <- function(params, ...) {
|
|||||||
names(params) <- gsub("\\.", "_", names(params))
|
names(params) <- gsub("\\.", "_", names(params))
|
||||||
|
|
||||||
# merge parameters from the params and the dots-expansion
|
# merge parameters from the params and the dots-expansion
|
||||||
dot.params <- list(...)
|
dot_params <- list(...)
|
||||||
names(dot.params) <- gsub("\\.", "_", names(dot.params))
|
names(dot_params) <- gsub("\\.", "_", names(dot_params))
|
||||||
if (length(intersect(names(params), names(dot.params))) > 0)
|
if (length(intersect(names(params),
|
||||||
|
names(dot_params))) > 0)
|
||||||
stop("Same parameters in 'params' and in the call are not allowed. Please check your 'params' list.")
|
stop("Same parameters in 'params' and in the call are not allowed. Please check your 'params' list.")
|
||||||
params <- c(params, dot.params)
|
params <- c(params, dot_params)
|
||||||
|
|
||||||
# only multiple eval_metric's make sense
|
# providing a parameter multiple times only makes sense for 'eval_metric'
|
||||||
name.freqs <- table(names(params))
|
name_freqs <- table(names(params))
|
||||||
multi.names <- setdiff( names(name.freqs[name.freqs>1]), 'eval_metric')
|
multi_names <- setdiff(names(name_freqs[name_freqs > 1]), 'eval_metric')
|
||||||
if (length(multi.names) > 0) {
|
if (length(multi_names) > 0) {
|
||||||
warning("The following parameters (other than 'eval_metric') were provided multiple times:\n\t",
|
warning("The following parameters were provided multiple times:\n\t",
|
||||||
paste(multi.names, collapse=', '), "\n Only the last value for each of them will be used.\n")
|
paste(multi_names, collapse=', '), "\n Only the last value for each of them will be used.\n")
|
||||||
# While xgboost itself would choose the last value for a multi-parameter,
|
# While xgboost itself would choose the last value for a multi-parameter,
|
||||||
# will do some clean-up here b/c multi-parameters could be used further in R code, and R would
|
# will do some clean-up here b/c multi-parameters could be used further in R code, and R would
|
||||||
# pick the 1st (not the last) value when multiple elements with the same name are present in a list.
|
# pick the 1st (not the last) value when multiple elements with the same name are present in a list.
|
||||||
for (n in multi.names) {
|
for (n in multi_names) {
|
||||||
del.idx <- which(n == names(params))
|
del_idx <- which(n == names(params))
|
||||||
del.idx <- del.idx[-length(del.idx)]
|
del_idx <- del_idx[-length(del_idx)]
|
||||||
params[[del.idx]] <- NULL
|
params[[del_idx]] <- NULL
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -99,7 +100,7 @@ check.custom.eval <- function(env = parent.frame()) {
|
|||||||
if (!is.null(env$feval) && is.null(env$maximize))
|
if (!is.null(env$feval) && is.null(env$maximize))
|
||||||
stop("Please set 'maximize' to indicate whether the metric needs to be maximized or not")
|
stop("Please set 'maximize' to indicate whether the metric needs to be maximized or not")
|
||||||
|
|
||||||
# handle the situation when custom eval function was provided through params
|
# handle a situation when custom eval function was provided through params
|
||||||
if (!is.null(env$params[['eval_metric']]) &&
|
if (!is.null(env$params[['eval_metric']]) &&
|
||||||
typeof(env$params$eval_metric) == 'closure') {
|
typeof(env$params$eval_metric) == 'closure') {
|
||||||
env$feval <- env$params$eval_metric
|
env$feval <- env$params$eval_metric
|
||||||
@ -136,7 +137,8 @@ xgb.iter.update <- function(booster, dtrain, iter, obj = NULL) {
|
|||||||
# with the names in a 'datasetname-metricname' format.
|
# with the names in a 'datasetname-metricname' format.
|
||||||
xgb.iter.eval <- function(booster, watchlist, iter, feval = NULL) {
|
xgb.iter.eval <- function(booster, watchlist, iter, feval = NULL) {
|
||||||
if (class(booster) != "xgb.Booster.handle")
|
if (class(booster) != "xgb.Booster.handle")
|
||||||
stop("first argument must be type xgb.Booster.handle")
|
stop("first argument type must be xgb.Booster.handle")
|
||||||
|
|
||||||
if (length(watchlist) == 0)
|
if (length(watchlist) == 0)
|
||||||
return(NULL)
|
return(NULL)
|
||||||
|
|
||||||
@ -171,15 +173,15 @@ generate.cv.folds <- function(nfold, nrows, stratified, label, params) {
|
|||||||
# cannot do it for rank
|
# cannot do it for rank
|
||||||
if (exists('objective', where=params) &&
|
if (exists('objective', where=params) &&
|
||||||
is.character(params$objective) &&
|
is.character(params$objective) &&
|
||||||
strtrim(params$objective, 5) == 'rank:')
|
strtrim(params$objective, 5) == 'rank:') {
|
||||||
stop("\n\tAutomatic generation of CV-folds is not implemented for ranking!\n",
|
stop("\n\tAutomatic generation of CV-folds is not implemented for ranking!\n",
|
||||||
"\tConsider providing pre-computed CV-folds through the 'folds=' parameter.\n")
|
"\tConsider providing pre-computed CV-folds through the 'folds=' parameter.\n")
|
||||||
|
}
|
||||||
# shuffle
|
# shuffle
|
||||||
rnd.idx <- sample(1:nrows)
|
rnd_idx <- sample(1:nrows)
|
||||||
if (stratified &&
|
if (stratified &&
|
||||||
length(label) == length(rnd.idx)) {
|
length(label) == length(rnd_idx)) {
|
||||||
y <- label[rnd.idx]
|
y <- label[rnd_idx]
|
||||||
# WARNING: some heuristic logic is employed to identify classification setting!
|
# WARNING: some heuristic logic is employed to identify classification setting!
|
||||||
# - For classification, need to convert y labels to factor before making the folds,
|
# - For classification, need to convert y labels to factor before making the folds,
|
||||||
# and then do stratification by factor levels.
|
# and then do stratification by factor levels.
|
||||||
@ -200,13 +202,13 @@ generate.cv.folds <- function(nfold, nrows, stratified, label, params) {
|
|||||||
folds <- xgb.createFolds(y, nfold)
|
folds <- xgb.createFolds(y, nfold)
|
||||||
} else {
|
} else {
|
||||||
# make simple non-stratified folds
|
# make simple non-stratified folds
|
||||||
kstep <- length(rnd.idx) %/% nfold
|
kstep <- length(rnd_idx) %/% nfold
|
||||||
folds <- list()
|
folds <- list()
|
||||||
for (i in 1:(nfold - 1)) {
|
for (i in 1:(nfold - 1)) {
|
||||||
folds[[i]] <- rnd.idx[1:kstep]
|
folds[[i]] <- rnd_idx[1:kstep]
|
||||||
rnd.idx <- rnd.idx[-(1:kstep)]
|
rnd_idx <- rnd_idx[-(1:kstep)]
|
||||||
}
|
}
|
||||||
folds[[nfold]] <- rnd.idx
|
folds[[nfold]] <- rnd_idx
|
||||||
}
|
}
|
||||||
return(folds)
|
return(folds)
|
||||||
}
|
}
|
||||||
@ -216,7 +218,7 @@ generate.cv.folds <- function(nfold, nrows, stratified, label, params) {
|
|||||||
# by always returning an unnamed list of fold indices.
|
# by always returning an unnamed list of fold indices.
|
||||||
xgb.createFolds <- function(y, k = 10)
|
xgb.createFolds <- function(y, k = 10)
|
||||||
{
|
{
|
||||||
if(is.numeric(y)) {
|
if (is.numeric(y)) {
|
||||||
## Group the numeric data based on their magnitudes
|
## Group the numeric data based on their magnitudes
|
||||||
## and sample within those groups.
|
## and sample within those groups.
|
||||||
|
|
||||||
@ -235,7 +237,7 @@ xgb.createFolds <- function(y, k = 10)
|
|||||||
include.lowest = TRUE)
|
include.lowest = TRUE)
|
||||||
}
|
}
|
||||||
|
|
||||||
if(k < length(y)) {
|
if (k < length(y)) {
|
||||||
## reset levels so that the possible levels and
|
## reset levels so that the possible levels and
|
||||||
## the levels in the vector are the same
|
## the levels in the vector are the same
|
||||||
y <- factor(as.character(y))
|
y <- factor(as.character(y))
|
||||||
@ -245,19 +247,83 @@ xgb.createFolds <- function(y, k = 10)
|
|||||||
## For each class, balance the fold allocation as far
|
## For each class, balance the fold allocation as far
|
||||||
## as possible, then resample the remainder.
|
## as possible, then resample the remainder.
|
||||||
## The final assignment of folds is also randomized.
|
## The final assignment of folds is also randomized.
|
||||||
for(i in 1:length(numInClass)) {
|
for (i in 1:length(numInClass)) {
|
||||||
## create a vector of integers from 1:k as many times as possible without
|
## create a vector of integers from 1:k as many times as possible without
|
||||||
## going over the number of samples in the class. Note that if the number
|
## going over the number of samples in the class. Note that if the number
|
||||||
## of samples in a class is less than k, nothing is producd here.
|
## of samples in a class is less than k, nothing is producd here.
|
||||||
seqVector <- rep(1:k, numInClass[i] %/% k)
|
seqVector <- rep(1:k, numInClass[i] %/% k)
|
||||||
## add enough random integers to get length(seqVector) == numInClass[i]
|
## add enough random integers to get length(seqVector) == numInClass[i]
|
||||||
if(numInClass[i] %% k > 0) seqVector <- c(seqVector, sample(1:k, numInClass[i] %% k))
|
if (numInClass[i] %% k > 0) seqVector <- c(seqVector, sample(1:k, numInClass[i] %% k))
|
||||||
## shuffle the integers for fold assignment and assign to this classes's data
|
## shuffle the integers for fold assignment and assign to this classes's data
|
||||||
foldVector[which(y == dimnames(numInClass)$y[i])] <- sample(seqVector)
|
foldVector[which(y == dimnames(numInClass)$y[i])] <- sample(seqVector)
|
||||||
}
|
}
|
||||||
} else foldVector <- seq(along = y)
|
} else {
|
||||||
|
foldVector <- seq(along = y)
|
||||||
|
}
|
||||||
|
|
||||||
out <- split(seq(along = y), foldVector)
|
out <- split(seq(along = y), foldVector)
|
||||||
names(out) <- NULL
|
names(out) <- NULL
|
||||||
out
|
out
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
#
|
||||||
|
# Deprecation notice utilities ------------------------------------------------
|
||||||
|
#
|
||||||
|
|
||||||
|
#' Deprecation notices.
|
||||||
|
#'
|
||||||
|
#' At this time, some of the parameter names were changed in order to make the code style more uniform.
|
||||||
|
#' The deprecated parameters would be removed in the next release.
|
||||||
|
#'
|
||||||
|
#' To see all the current deprecated and new parameters, check the \code{xgboost:::depr_par_lut} table.
|
||||||
|
#'
|
||||||
|
#' A deprecation warning is shown when any of the deprecated parameters is used in a call.
|
||||||
|
#' An additional warning is shown when there was a partial match to a deprecated parameter
|
||||||
|
#' (as R is able to partially match parameter names).
|
||||||
|
#'
|
||||||
|
#' @name xgboost-deprecated
|
||||||
|
NULL
|
||||||
|
|
||||||
|
# Lookup table used for bookkeeping of deprecated parameter names.
# A two-column character matrix: column 'old' holds the deprecated name,
# column 'new' holds the replacement name found in the same row.
# NOTE(review): the last row ('dummy' -> 'DUMMY') looks like a placeholder
# entry -- confirm whether it is still needed.
depr_par_lut <- cbind(
  old = c('print.every.n',
          'early.stop.round',
          'training.data',
          'with.stats',
          'numberOfClusters',
          'features.keep',
          'plot.height',
          'plot.width',
          'dummy'),
  new = c('print_every_n',
          'early_stopping_rounds',
          'data',
          'with_stats',
          'n_clusters',
          'features_keep',
          'plot_height',
          'plot_width',
          'DUMMY')
)
|
||||||
|
|
||||||
|
# Checks the dot-parameters for deprecated names (including partial matching),
# gives a deprecation warning, and sets new parameters to the old parameters'
# values within its parent frame.
# WARNING: has side-effects
#
# Arguments:
#   ...  named arguments to inspect (typically the caller's own `...`).
#        Names are matched against the 'old' column of `depr_par_lut` with
#        `pmatch`, so both exact and unambiguous partial names are detected;
#        unnamed or ambiguous arguments are ignored.
#   env  environment in which the replacement variables are created
#        (defaults to the caller's frame).
#
# Side effects, for every matched argument:
#   - an extra warning when the supplied name was only a partial match;
#   - a deprecation warning via `.Deprecated`;
#   - a variable named after the LUT's 'new' entry is assigned the supplied
#     value in `env`, unless that entry is the string 'NULL'.
#
# Returns NULL.
check.deprecation <- function(..., env = parent.frame()) {
  pars <- list(...)
  par_names <- names(pars)
  # exact and partial matches (NA for no match or an ambiguous partial match)
  all_match <- pmatch(par_names, depr_par_lut[, 1])
  # indices of matched pars' names
  idx_pars <- which(!is.na(all_match))
  if (length(idx_pars) == 0) return()
  # indices of matched LUT rows
  idx_lut <- all_match[idx_pars]
  for (i in seq_along(idx_pars)) {
    pars_par <- par_names[idx_pars[i]]
    old_par <- depr_par_lut[idx_lut[i], 1]
    new_par <- depr_par_lut[idx_lut[i], 2]
    # a supplied name that differs from the LUT name was a partial match
    if (pars_par != old_par) {
      warning("'", pars_par, "' was partially matched to '", old_par,"'")
    }
    .Deprecated(new_par, old = old_par, package = 'xgboost')
    if (new_par != 'NULL') {
      # Assign the value itself rather than pasting it into code and
      # re-parsing it: that round-trip broke for character values (evaluated
      # as symbols), for non-atomic objects such as matrices, and for vectors,
      # and it truncated the precision of doubles.
      assign(new_par, pars[[idx_pars[i]]], envir = env)
    }
  }
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user