[R-package] a few fixes for R (#1485)
* [R] fix #1465 * [R] add sanity check to fix #1434 * [R] some clean-ups for custom obj&eval; require maximize only for early stopping
This commit is contained in:
parent
b8e6551734
commit
bdfa8c0e09
@ -458,6 +458,7 @@ cb.save.model <- function(save_period = 0, save_name = "xgboost.model") {
|
|||||||
#' \code{basket},
|
#' \code{basket},
|
||||||
#' \code{data},
|
#' \code{data},
|
||||||
#' \code{end_iteration},
|
#' \code{end_iteration},
|
||||||
|
#' \code{params},
|
||||||
#' \code{num_parallel_tree},
|
#' \code{num_parallel_tree},
|
||||||
#' \code{num_class}.
|
#' \code{num_class}.
|
||||||
#'
|
#'
|
||||||
@ -491,6 +492,9 @@ cb.cv.predict <- function(save_models = FALSE) {
|
|||||||
|
|
||||||
ntreelimit <- NVL(env$basket$best_ntreelimit,
|
ntreelimit <- NVL(env$basket$best_ntreelimit,
|
||||||
env$end_iteration * env$num_parallel_tree)
|
env$end_iteration * env$num_parallel_tree)
|
||||||
|
if (NVL(env$params[['booster']], '') == 'gblinear') {
|
||||||
|
ntreelimit <- 0 # must be 0 for gblinear
|
||||||
|
}
|
||||||
for (fd in env$bst_folds) {
|
for (fd in env$bst_folds) {
|
||||||
pr <- predict(fd$bst, fd$watchlist[[2]], ntreelimit = ntreelimit, reshape = TRUE)
|
pr <- predict(fd$bst, fd$watchlist[[2]], ntreelimit = ntreelimit, reshape = TRUE)
|
||||||
if (is.matrix(pred)) {
|
if (is.matrix(pred)) {
|
||||||
|
|||||||
@ -17,7 +17,7 @@ NVL <- function(x, val) {
|
|||||||
}
|
}
|
||||||
if (typeof(x) == 'closure')
|
if (typeof(x) == 'closure')
|
||||||
return(x)
|
return(x)
|
||||||
stop('x of unsupported for NVL type')
|
stop("typeof(x) == ", typeof(x), " is not supported by NVL")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -42,15 +42,15 @@ check.booster.params <- function(params, ...) {
|
|||||||
stop("Same parameters in 'params' and in the call are not allowed. Please check your 'params' list.")
|
stop("Same parameters in 'params' and in the call are not allowed. Please check your 'params' list.")
|
||||||
params <- c(params, dot_params)
|
params <- c(params, dot_params)
|
||||||
|
|
||||||
# providing a parameter multiple times only makes sense for 'eval_metric'
|
# providing a parameter multiple times makes sense only for 'eval_metric'
|
||||||
name_freqs <- table(names(params))
|
name_freqs <- table(names(params))
|
||||||
multi_names <- setdiff(names(name_freqs[name_freqs > 1]), 'eval_metric')
|
multi_names <- setdiff(names(name_freqs[name_freqs > 1]), 'eval_metric')
|
||||||
if (length(multi_names) > 0) {
|
if (length(multi_names) > 0) {
|
||||||
warning("The following parameters were provided multiple times:\n\t",
|
warning("The following parameters were provided multiple times:\n\t",
|
||||||
paste(multi_names, collapse=', '), "\n Only the last value for each of them will be used.\n")
|
paste(multi_names, collapse=', '), "\n Only the last value for each of them will be used.\n")
|
||||||
# While xgboost itself would choose the last value for a multi-parameter,
|
# While xgboost internals would choose the last value for a multiple-times parameter,
|
||||||
# will do some clean-up here b/c multi-parameters could be used further in R code, and R would
|
# enforce it here in R as well (b/c multi-parameters might be used further in R code,
|
||||||
# pick the 1st (not the last) value when multiple elements with the same name are present in a list.
|
# and R takes the 1st value when multiple elements with the same name are present in a list).
|
||||||
for (n in multi_names) {
|
for (n in multi_names) {
|
||||||
del_idx <- which(n == names(params))
|
del_idx <- which(n == names(params))
|
||||||
del_idx <- del_idx[-length(del_idx)]
|
del_idx <- del_idx[-length(del_idx)]
|
||||||
@ -60,9 +60,9 @@ check.booster.params <- function(params, ...) {
|
|||||||
|
|
||||||
# for multiclass, expect num_class to be set
|
# for multiclass, expect num_class to be set
|
||||||
if (typeof(params[['objective']]) == "character" &&
|
if (typeof(params[['objective']]) == "character" &&
|
||||||
substr(NVL(params[['objective']], 'x'), 1, 6) == 'multi:') {
|
substr(NVL(params[['objective']], 'x'), 1, 6) == 'multi:' &&
|
||||||
if (as.numeric(NVL(params[['num_class']], 0)) < 2)
|
as.numeric(NVL(params[['num_class']], 0)) < 2) {
|
||||||
stop("'num_class' > 1 parameter must be set for multiclass classification")
|
stop("'num_class' > 1 parameter must be set for multiclass classification")
|
||||||
}
|
}
|
||||||
|
|
||||||
return(params)
|
return(params)
|
||||||
@ -82,9 +82,7 @@ check.custom.obj <- function(env = parent.frame()) {
|
|||||||
if (!is.null(env$params[['objective']]) &&
|
if (!is.null(env$params[['objective']]) &&
|
||||||
typeof(env$params$objective) == 'closure') {
|
typeof(env$params$objective) == 'closure') {
|
||||||
env$obj <- env$params$objective
|
env$obj <- env$params$objective
|
||||||
p <- env$params
|
env$params$objective <- NULL
|
||||||
p$objective <- NULL
|
|
||||||
env$params <- p
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -97,17 +95,19 @@ check.custom.eval <- function(env = parent.frame()) {
|
|||||||
if (!is.null(env$feval) && typeof(env$feval) != 'closure')
|
if (!is.null(env$feval) && typeof(env$feval) != 'closure')
|
||||||
stop("'feval' must be a function")
|
stop("'feval' must be a function")
|
||||||
|
|
||||||
if (!is.null(env$feval) && is.null(env$maximize))
|
|
||||||
stop("Please set 'maximize' to indicate whether the metric needs to be maximized or not")
|
|
||||||
|
|
||||||
# handle a situation when custom eval function was provided through params
|
# handle a situation when custom eval function was provided through params
|
||||||
if (!is.null(env$params[['eval_metric']]) &&
|
if (!is.null(env$params[['eval_metric']]) &&
|
||||||
typeof(env$params$eval_metric) == 'closure') {
|
typeof(env$params$eval_metric) == 'closure') {
|
||||||
env$feval <- env$params$eval_metric
|
env$feval <- env$params$eval_metric
|
||||||
p <- env$params
|
env$params$eval_metric <- NULL
|
||||||
p[ which(names(p) == 'eval_metric') ] <- NULL
|
|
||||||
env$params <- p
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# require maximize to be set when custom feval and early stopping are used together
|
||||||
|
if (!is.null(env$feval) &&
|
||||||
|
is.null(env$maximize) && (
|
||||||
|
!is.null(env$early_stopping_rounds) ||
|
||||||
|
has.callbacks(env$callbacks, 'cb.early.stop')))
|
||||||
|
stop("Please set 'maximize' to indicate whether the evaluation metric needs to be maximized or not")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -65,10 +65,15 @@ xgb.model.dt.tree <- function(feature_names = NULL, model = NULL, text = NULL,
|
|||||||
stop("n_first_tree: Has to be a numeric vector of size 1.")
|
stop("n_first_tree: Has to be a numeric vector of size 1.")
|
||||||
}
|
}
|
||||||
|
|
||||||
if(is.null(text)){
|
if (is.null(text)){
|
||||||
text <- xgb.dump(model = model, with_stats = T)
|
text <- xgb.dump(model = model, with_stats = T)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (length(text) < 2 ||
|
||||||
|
sum(stri_detect_regex(text, 'yes=(\\d+),no=(\\d+)')) < 1) {
|
||||||
|
stop("Non-tree model detected! This function can only be used with tree models.")
|
||||||
|
}
|
||||||
|
|
||||||
position <- which(!is.na(stri_match_first_regex(text, "booster")))
|
position <- which(!is.na(stri_match_first_regex(text, "booster")))
|
||||||
|
|
||||||
add.tree.id <- function(x, i) paste(i, x, sep = "-")
|
add.tree.id <- function(x, i) paste(i, x, sep = "-")
|
||||||
|
|||||||
@ -173,9 +173,10 @@
|
|||||||
#' watchlist <- list(eval = dtest, train = dtrain)
|
#' watchlist <- list(eval = dtest, train = dtrain)
|
||||||
#'
|
#'
|
||||||
#' ## A simple xgb.train example:
|
#' ## A simple xgb.train example:
|
||||||
#' param <- list(max_depth = 2, eta = 1, silent = 1,
|
#' param <- list(max_depth = 2, eta = 1, silent = 1, nthread = 2,
|
||||||
#' objective = "binary:logistic", eval_metric = "auc")
|
#' objective = "binary:logistic", eval_metric = "auc")
|
||||||
#' bst <- xgb.train(param, dtrain, nthread = 2, nrounds = 2, watchlist)
|
#' bst <- xgb.train(param, dtrain, nrounds = 2, watchlist)
|
||||||
|
#'
|
||||||
#'
|
#'
|
||||||
#' ## An xgb.train example where custom objective and evaluation metric are used:
|
#' ## An xgb.train example where custom objective and evaluation metric are used:
|
||||||
#' logregobj <- function(preds, dtrain) {
|
#' logregobj <- function(preds, dtrain) {
|
||||||
@ -190,16 +191,33 @@
|
|||||||
#' err <- as.numeric(sum(labels != (preds > 0)))/length(labels)
|
#' err <- as.numeric(sum(labels != (preds > 0)))/length(labels)
|
||||||
#' return(list(metric = "error", value = err))
|
#' return(list(metric = "error", value = err))
|
||||||
#' }
|
#' }
|
||||||
#' bst <- xgb.train(param, dtrain, nthread = 2, nrounds = 2, watchlist)
|
#'
|
||||||
|
#' # These functions could be used by passing them either:
|
||||||
|
#' # as 'objective' and 'eval_metric' parameters in the params list:
|
||||||
|
#' param <- list(max_depth = 2, eta = 1, silent = 1, nthread = 2,
|
||||||
|
#' objective = logregobj, eval_metric = evalerror)
|
||||||
|
#' bst <- xgb.train(param, dtrain, nrounds = 2, watchlist)
|
||||||
|
#'
|
||||||
|
#' # or through the ... arguments:
|
||||||
|
#' param <- list(max_depth = 2, eta = 1, silent = 1, nthread = 2)
|
||||||
|
#' bst <- xgb.train(param, dtrain, nrounds = 2, watchlist,
|
||||||
|
#' objective = logregobj, eval_metric = evalerror)
|
||||||
|
#'
|
||||||
|
#' # or as dedicated 'obj' and 'feval' parameters of xgb.train:
|
||||||
|
#' bst <- xgb.train(param, dtrain, nrounds = 2, watchlist,
|
||||||
|
#' obj = logregobj, feval = evalerror)
|
||||||
|
#'
|
||||||
#'
|
#'
|
||||||
#' ## An xgb.train example of using variable learning rates at each iteration:
|
#' ## An xgb.train example of using variable learning rates at each iteration:
|
||||||
|
#' param <- list(max_depth = 2, eta = 1, silent = 1, nthread = 2)
|
||||||
#' my_etas <- list(eta = c(0.5, 0.1))
|
#' my_etas <- list(eta = c(0.5, 0.1))
|
||||||
#' bst <- xgb.train(param, dtrain, nthread = 2, nrounds = 2, watchlist,
|
#' bst <- xgb.train(param, dtrain, nrounds = 2, watchlist,
|
||||||
#' callbacks = list(cb.reset.parameters(my_etas)))
|
#' callbacks = list(cb.reset.parameters(my_etas)))
|
||||||
#'
|
#'
|
||||||
|
#'
|
||||||
#' ## Explicit use of the cb.evaluation.log callback allows to run
|
#' ## Explicit use of the cb.evaluation.log callback allows to run
|
||||||
#' ## xgb.train silently but still store the evaluation results:
|
#' ## xgb.train silently but still store the evaluation results:
|
||||||
#' bst <- xgb.train(param, dtrain, nthread = 2, nrounds = 2, watchlist,
|
#' bst <- xgb.train(param, dtrain, nrounds = 2, watchlist,
|
||||||
#' verbose = 0, callbacks = list(cb.evaluation.log()))
|
#' verbose = 0, callbacks = list(cb.evaluation.log()))
|
||||||
#' print(bst$evaluation_log)
|
#' print(bst$evaluation_log)
|
||||||
#'
|
#'
|
||||||
|
|||||||
@ -34,6 +34,7 @@ Callback function expects the following values to be set in its calling frame:
|
|||||||
\code{basket},
|
\code{basket},
|
||||||
\code{data},
|
\code{data},
|
||||||
\code{end_iteration},
|
\code{end_iteration},
|
||||||
|
\code{params},
|
||||||
\code{num_parallel_tree},
|
\code{num_parallel_tree},
|
||||||
\code{num_class}.
|
\code{num_class}.
|
||||||
}
|
}
|
||||||
|
|||||||
@ -200,9 +200,10 @@ dtest <- xgb.DMatrix(agaricus.test$data, label = agaricus.test$label)
|
|||||||
watchlist <- list(eval = dtest, train = dtrain)
|
watchlist <- list(eval = dtest, train = dtrain)
|
||||||
|
|
||||||
## A simple xgb.train example:
|
## A simple xgb.train example:
|
||||||
param <- list(max_depth = 2, eta = 1, silent = 1,
|
param <- list(max_depth = 2, eta = 1, silent = 1, nthread = 2,
|
||||||
objective = "binary:logistic", eval_metric = "auc")
|
objective = "binary:logistic", eval_metric = "auc")
|
||||||
bst <- xgb.train(param, dtrain, nthread = 2, nrounds = 2, watchlist)
|
bst <- xgb.train(param, dtrain, nrounds = 2, watchlist)
|
||||||
|
|
||||||
|
|
||||||
## An xgb.train example where custom objective and evaluation metric are used:
|
## An xgb.train example where custom objective and evaluation metric are used:
|
||||||
logregobj <- function(preds, dtrain) {
|
logregobj <- function(preds, dtrain) {
|
||||||
@ -217,16 +218,33 @@ evalerror <- function(preds, dtrain) {
|
|||||||
err <- as.numeric(sum(labels != (preds > 0)))/length(labels)
|
err <- as.numeric(sum(labels != (preds > 0)))/length(labels)
|
||||||
return(list(metric = "error", value = err))
|
return(list(metric = "error", value = err))
|
||||||
}
|
}
|
||||||
bst <- xgb.train(param, dtrain, nthread = 2, nrounds = 2, watchlist)
|
|
||||||
|
# These functions could be used by passing them either:
|
||||||
|
# as 'objective' and 'eval_metric' parameters in the params list:
|
||||||
|
param <- list(max_depth = 2, eta = 1, silent = 1, nthread = 2,
|
||||||
|
objective = logregobj, eval_metric = evalerror)
|
||||||
|
bst <- xgb.train(param, dtrain, nrounds = 2, watchlist)
|
||||||
|
|
||||||
|
# or through the ... arguments:
|
||||||
|
param <- list(max_depth = 2, eta = 1, silent = 1, nthread = 2)
|
||||||
|
bst <- xgb.train(param, dtrain, nrounds = 2, watchlist,
|
||||||
|
objective = logregobj, eval_metric = evalerror)
|
||||||
|
|
||||||
|
# or as dedicated 'obj' and 'feval' parameters of xgb.train:
|
||||||
|
bst <- xgb.train(param, dtrain, nrounds = 2, watchlist,
|
||||||
|
obj = logregobj, feval = evalerror)
|
||||||
|
|
||||||
|
|
||||||
## An xgb.train example of using variable learning rates at each iteration:
|
## An xgb.train example of using variable learning rates at each iteration:
|
||||||
|
param <- list(max_depth = 2, eta = 1, silent = 1, nthread = 2)
|
||||||
my_etas <- list(eta = c(0.5, 0.1))
|
my_etas <- list(eta = c(0.5, 0.1))
|
||||||
bst <- xgb.train(param, dtrain, nthread = 2, nrounds = 2, watchlist,
|
bst <- xgb.train(param, dtrain, nrounds = 2, watchlist,
|
||||||
callbacks = list(cb.reset.parameters(my_etas)))
|
callbacks = list(cb.reset.parameters(my_etas)))
|
||||||
|
|
||||||
|
|
||||||
## Explicit use of the cb.evaluation.log callback allows to run
|
## Explicit use of the cb.evaluation.log callback allows to run
|
||||||
## xgb.train silently but still store the evaluation results:
|
## xgb.train silently but still store the evaluation results:
|
||||||
bst <- xgb.train(param, dtrain, nthread = 2, nrounds = 2, watchlist,
|
bst <- xgb.train(param, dtrain, nrounds = 2, watchlist,
|
||||||
verbose = 0, callbacks = list(cb.evaluation.log()))
|
verbose = 0, callbacks = list(cb.evaluation.log()))
|
||||||
print(bst$evaluation_log)
|
print(bst$evaluation_log)
|
||||||
|
|
||||||
|
|||||||
@ -260,6 +260,15 @@ test_that("prediction in xgb.cv works", {
|
|||||||
expect_true(all(sapply(cvx$models, class) == 'xgb.Booster'))
|
expect_true(all(sapply(cvx$models, class) == 'xgb.Booster'))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
test_that("prediction in xgb.cv works for gblinear too", {
|
||||||
|
set.seed(11)
|
||||||
|
p <- list(booster = 'gblinear', objective = "reg:logistic", nthread = 2)
|
||||||
|
cv <- xgb.cv(p, dtrain, nfold = 5, eta = 0.5, nrounds = 2, prediction = TRUE)
|
||||||
|
expect_false(is.null(cv$evaluation_log))
|
||||||
|
expect_false(is.null(cv$pred))
|
||||||
|
expect_length(cv$pred, nrow(train$data))
|
||||||
|
})
|
||||||
|
|
||||||
test_that("prediction in early-stopping xgb.cv works", {
|
test_that("prediction in early-stopping xgb.cv works", {
|
||||||
set.seed(1)
|
set.seed(1)
|
||||||
expect_output(
|
expect_output(
|
||||||
|
|||||||
@ -81,6 +81,10 @@ test_that("xgb.model.dt.tree works with and without feature names", {
|
|||||||
expect_output(str(xgb.model.dt.tree(model = bst.Tree)), 'Feature.*\\"3\\"')
|
expect_output(str(xgb.model.dt.tree(model = bst.Tree)), 'Feature.*\\"3\\"')
|
||||||
})
|
})
|
||||||
|
|
||||||
|
test_that("xgb.model.dt.tree throws error for gblinear", {
|
||||||
|
expect_error(xgb.model.dt.tree(model = bst.GLM))
|
||||||
|
})
|
||||||
|
|
||||||
test_that("xgb.importance works with and without feature names", {
|
test_that("xgb.importance works with and without feature names", {
|
||||||
importance.Tree <- xgb.importance(feature_names = feature.names, model = bst.Tree)
|
importance.Tree <- xgb.importance(feature_names = feature.names, model = bst.Tree)
|
||||||
expect_equal(dim(importance.Tree), c(7, 4))
|
expect_equal(dim(importance.Tree), c(7, 4))
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user