[R] Remove parameters and attributes related to ntree and rebase iterationrange (#9935)
This commit is contained in:
parent
60b9d2eeb9
commit
c5d0608057
@ -280,7 +280,6 @@ cb.reset.parameters <- function(new_params) {
|
||||
#' \code{iteration},
|
||||
#' \code{begin_iteration},
|
||||
#' \code{end_iteration},
|
||||
#' \code{num_parallel_tree}.
|
||||
#'
|
||||
#' @seealso
|
||||
#' \code{\link{callbacks}},
|
||||
@ -291,7 +290,6 @@ cb.early.stop <- function(stopping_rounds, maximize = FALSE,
|
||||
metric_name = NULL, verbose = TRUE) {
|
||||
# state variables
|
||||
best_iteration <- -1
|
||||
best_ntreelimit <- -1
|
||||
best_score <- Inf
|
||||
best_msg <- NULL
|
||||
metric_idx <- 1
|
||||
@ -358,12 +356,10 @@ cb.early.stop <- function(stopping_rounds, maximize = FALSE,
|
||||
# If the difference is due to floating-point truncation, update best_score
|
||||
best_score <- attr_best_score
|
||||
}
|
||||
xgb.attr(env$bst, "best_iteration") <- best_iteration
|
||||
xgb.attr(env$bst, "best_ntreelimit") <- best_ntreelimit
|
||||
xgb.attr(env$bst, "best_iteration") <- best_iteration - 1
|
||||
xgb.attr(env$bst, "best_score") <- best_score
|
||||
} else {
|
||||
env$basket$best_iteration <- best_iteration
|
||||
env$basket$best_ntreelimit <- best_ntreelimit
|
||||
}
|
||||
}
|
||||
|
||||
@ -385,14 +381,13 @@ cb.early.stop <- function(stopping_rounds, maximize = FALSE,
|
||||
)
|
||||
best_score <<- score
|
||||
best_iteration <<- i
|
||||
best_ntreelimit <<- best_iteration * env$num_parallel_tree
|
||||
# save the property to attributes, so they will occur in checkpoint
|
||||
if (!is.null(env$bst)) {
|
||||
xgb.attributes(env$bst) <- list(
|
||||
best_iteration = best_iteration - 1, # convert to 0-based index
|
||||
best_score = best_score,
|
||||
best_msg = best_msg,
|
||||
best_ntreelimit = best_ntreelimit)
|
||||
best_msg = best_msg
|
||||
)
|
||||
}
|
||||
} else if (i - best_iteration >= stopping_rounds) {
|
||||
env$stop_condition <- TRUE
|
||||
@ -475,8 +470,6 @@ cb.save.model <- function(save_period = 0, save_name = "xgboost.ubj") {
|
||||
#' \code{data},
|
||||
#' \code{end_iteration},
|
||||
#' \code{params},
|
||||
#' \code{num_parallel_tree},
|
||||
#' \code{num_class}.
|
||||
#'
|
||||
#' @return
|
||||
#' Predictions are returned inside of the \code{pred} element, which is either a vector or a matrix,
|
||||
@ -499,19 +492,21 @@ cb.cv.predict <- function(save_models = FALSE) {
|
||||
stop("'cb.cv.predict' callback requires 'basket' and 'bst_folds' lists in its calling frame")
|
||||
|
||||
N <- nrow(env$data)
|
||||
pred <-
|
||||
if (env$num_class > 1) {
|
||||
matrix(NA_real_, N, env$num_class)
|
||||
} else {
|
||||
rep(NA_real_, N)
|
||||
}
|
||||
pred <- NULL
|
||||
|
||||
iterationrange <- c(1, NVL(env$basket$best_iteration, env$end_iteration) + 1)
|
||||
iterationrange <- c(1, NVL(env$basket$best_iteration, env$end_iteration))
|
||||
if (NVL(env$params[['booster']], '') == 'gblinear') {
|
||||
iterationrange <- c(1, 1) # must be 0 for gblinear
|
||||
iterationrange <- "all"
|
||||
}
|
||||
for (fd in env$bst_folds) {
|
||||
pr <- predict(fd$bst, fd$watchlist[[2]], iterationrange = iterationrange, reshape = TRUE)
|
||||
if (is.null(pred)) {
|
||||
if (NCOL(pr) > 1L) {
|
||||
pred <- matrix(NA_real_, N, ncol(pr))
|
||||
} else {
|
||||
pred <- matrix(NA_real_, N)
|
||||
}
|
||||
}
|
||||
if (is.matrix(pred)) {
|
||||
pred[fd$index, ] <- pr
|
||||
} else {
|
||||
|
||||
@ -208,7 +208,7 @@ xgb.iter.eval <- function(bst, watchlist, iter, feval) {
|
||||
res <- sapply(seq_along(watchlist), function(j) {
|
||||
w <- watchlist[[j]]
|
||||
## predict using all trees
|
||||
preds <- predict(bst, w, outputmargin = TRUE, iterationrange = c(1, 1))
|
||||
preds <- predict(bst, w, outputmargin = TRUE, iterationrange = "all")
|
||||
eval_res <- feval(preds, w)
|
||||
out <- eval_res$value
|
||||
names(out) <- paste0(evnames[j], "-", eval_res$metric)
|
||||
|
||||
@ -89,7 +89,6 @@ xgb.get.handle <- function(object) {
|
||||
#' @param outputmargin Whether the prediction should be returned in the form of original untransformed
|
||||
#' sum of predictions from boosting iterations' results. E.g., setting `outputmargin=TRUE` for
|
||||
#' logistic regression would return log-odds instead of probabilities.
|
||||
#' @param ntreelimit Deprecated, use `iterationrange` instead.
|
||||
#' @param predleaf Whether to predict pre-tree leaf indices.
|
||||
#' @param predcontrib Whether to return feature contributions to individual predictions (see Details).
|
||||
#' @param approxcontrib Whether to use a fast approximation for feature contributions (see Details).
|
||||
@ -99,11 +98,17 @@ xgb.get.handle <- function(object) {
|
||||
#' or `predinteraction` is `TRUE`.
|
||||
#' @param training Whether the predictions are used for training. For dart booster,
|
||||
#' training predicting will perform dropout.
|
||||
#' @param iterationrange Specifies which trees are used in prediction. For
|
||||
#' example, take a random forest with 100 rounds.
|
||||
#' With `iterationrange=c(1, 21)`, only the trees built during `[1, 21)` (half open set)
|
||||
#' rounds are used in this prediction. The index is 1-based just like an R vector. When set
|
||||
#' to `c(1, 1)`, XGBoost will use all trees.
|
||||
#' @param iterationrange Sequence of rounds/iterations from the model to use for prediction, specified by passing
|
||||
#' a two-dimensional vector with the start and end numbers in the sequence (same format as R's `seq` - i.e.
|
||||
#' base-1 indexing, and inclusive of both ends).
|
||||
#'
|
||||
#' For example, passing `c(1,20)` will predict using the first twenty iterations, while passing `c(1,1)` will
|
||||
#' predict using only the first one.
|
||||
#'
|
||||
#' If passing `NULL`, will either stop at the best iteration if the model used early stopping, or use all
|
||||
#' of the iterations (rounds) otherwise.
|
||||
#'
|
||||
#' If passing "all", will use all of the rounds regardless of whether the model had early stopping or not.
|
||||
#' @param strict_shape Default is `FALSE`. When set to `TRUE`, the output
|
||||
#' type and shape of predictions are invariant to the model type.
|
||||
#' @param ... Not used.
|
||||
@ -189,7 +194,7 @@ xgb.get.handle <- function(object) {
|
||||
#' # use all trees by default
|
||||
#' pred <- predict(bst, test$data)
|
||||
#' # use only the 1st tree
|
||||
#' pred1 <- predict(bst, test$data, iterationrange = c(1, 2))
|
||||
#' pred1 <- predict(bst, test$data, iterationrange = c(1, 1))
|
||||
#'
|
||||
#' # Predicting tree leafs:
|
||||
#' # the result is an nsamples X ntrees matrix
|
||||
@ -260,11 +265,11 @@ xgb.get.handle <- function(object) {
|
||||
#' all.equal(pred, pred_labels)
|
||||
#' # prediction from using only 5 iterations should result
|
||||
#' # in the same error as seen in iteration 5:
|
||||
#' pred5 <- predict(bst, as.matrix(iris[, -5]), iterationrange = c(1, 6))
|
||||
#' pred5 <- predict(bst, as.matrix(iris[, -5]), iterationrange = c(1, 5))
|
||||
#' sum(pred5 != lb) / length(lb)
|
||||
#'
|
||||
#' @export
|
||||
predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FALSE, ntreelimit = NULL,
|
||||
predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FALSE,
|
||||
predleaf = FALSE, predcontrib = FALSE, approxcontrib = FALSE, predinteraction = FALSE,
|
||||
reshape = FALSE, training = FALSE, iterationrange = NULL, strict_shape = FALSE, ...) {
|
||||
if (!inherits(newdata, "xgb.DMatrix")) {
|
||||
@ -275,25 +280,21 @@ predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FA
|
||||
)
|
||||
}
|
||||
|
||||
if (NVL(xgb.booster_type(object), '') == 'gblinear' || is.null(ntreelimit))
|
||||
ntreelimit <- 0
|
||||
|
||||
if (ntreelimit != 0 && is.null(iterationrange)) {
|
||||
## only ntreelimit, initialize iteration range
|
||||
iterationrange <- c(0, 0)
|
||||
} else if (ntreelimit == 0 && !is.null(iterationrange)) {
|
||||
## only iteration range, handle 1-based indexing
|
||||
iterationrange <- c(iterationrange[1] - 1, iterationrange[2] - 1)
|
||||
} else if (ntreelimit != 0 && !is.null(iterationrange)) {
|
||||
## both are specified, let libgxgboost throw an error
|
||||
if (!is.null(iterationrange)) {
|
||||
if (is.character(iterationrange)) {
|
||||
stopifnot(iterationrange == "all")
|
||||
iterationrange <- c(0, 0)
|
||||
} else {
|
||||
iterationrange[1] <- iterationrange[1] - 1 # base-0 indexing
|
||||
}
|
||||
} else {
|
||||
## no limit is supplied, use best
|
||||
best_iteration <- xgb.best_iteration(object)
|
||||
if (is.null(best_iteration)) {
|
||||
iterationrange <- c(0, 0)
|
||||
} else {
|
||||
## We don't need to + 1 as R is 1-based index.
|
||||
iterationrange <- c(0, as.integer(best_iteration))
|
||||
iterationrange <- c(0, as.integer(best_iteration) + 1L)
|
||||
}
|
||||
}
|
||||
## Handle the 0 length values.
|
||||
@ -312,7 +313,6 @@ predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FA
|
||||
strict_shape = box(TRUE),
|
||||
iteration_begin = box(as.integer(iterationrange[1])),
|
||||
iteration_end = box(as.integer(iterationrange[2])),
|
||||
ntree_limit = box(as.integer(ntreelimit)),
|
||||
type = box(as.integer(0))
|
||||
)
|
||||
|
||||
@ -500,7 +500,7 @@ xgb.attr <- function(object, name) {
|
||||
return(NULL)
|
||||
}
|
||||
if (!is.null(out)) {
|
||||
if (name %in% c("best_iteration", "best_ntreelimit", "best_score")) {
|
||||
if (name %in% c("best_iteration", "best_score")) {
|
||||
out <- as.numeric(out)
|
||||
}
|
||||
}
|
||||
@ -718,12 +718,6 @@ variable.names.xgb.Booster <- function(object, ...) {
|
||||
return(getinfo(object, "feature_name"))
|
||||
}
|
||||
|
||||
xgb.ntree <- function(bst) {
|
||||
config <- xgb.config(bst)
|
||||
out <- strtoi(config$learner$gradient_booster$gbtree_model_param$num_trees)
|
||||
return(out)
|
||||
}
|
||||
|
||||
xgb.nthread <- function(bst) {
|
||||
config <- xgb.config(bst)
|
||||
out <- strtoi(config$learner$generic_param$nthread)
|
||||
|
||||
@ -103,7 +103,6 @@
|
||||
#' parameter or randomly generated.
|
||||
#' \item \code{best_iteration} iteration number with the best evaluation metric value
|
||||
#' (only available with early stopping).
|
||||
#' \item \code{best_ntreelimit} and the \code{ntreelimit} Deprecated attributes, use \code{best_iteration} instead.
|
||||
#' \item \code{pred} CV prediction values available when \code{prediction} is set.
|
||||
#' It is either vector or matrix (see \code{\link{cb.cv.predict}}).
|
||||
#' \item \code{models} a list of the CV folds' models. It is only available with the explicit
|
||||
@ -218,7 +217,6 @@ xgb.cv <- function(params = list(), data, nrounds, nfold, label = NULL, missing
|
||||
|
||||
# extract parameters that can affect the relationship b/w #trees and #iterations
|
||||
num_class <- max(as.numeric(NVL(params[['num_class']], 1)), 1) # nolint
|
||||
num_parallel_tree <- max(as.numeric(NVL(params[['num_parallel_tree']], 1)), 1) # nolint
|
||||
|
||||
# those are fixed for CV (no training continuation)
|
||||
begin_iteration <- 1
|
||||
@ -318,7 +316,7 @@ print.xgb.cv.synchronous <- function(x, verbose = FALSE, ...) {
|
||||
})
|
||||
}
|
||||
|
||||
for (n in c('niter', 'best_iteration', 'best_ntreelimit')) {
|
||||
for (n in c('niter', 'best_iteration')) {
|
||||
if (is.null(x[[n]]))
|
||||
next
|
||||
cat(n, ': ', x[[n]], '\n', sep = '')
|
||||
|
||||
@ -393,7 +393,6 @@ xgb.train <- function(params = list(), data, nrounds, watchlist = list(),
|
||||
# Note: it might look like these aren't used, but they need to be defined in this
|
||||
# environment for the callbacks for work correctly.
|
||||
num_class <- max(as.numeric(NVL(params[['num_class']], 1)), 1) # nolint
|
||||
num_parallel_tree <- max(as.numeric(NVL(params[['num_parallel_tree']], 1)), 1) # nolint
|
||||
|
||||
if (is_update && nrounds > niter_init)
|
||||
stop("nrounds cannot be larger than ", niter_init, " (nrounds of xgb_model)")
|
||||
|
||||
@ -15,7 +15,7 @@ cat('start testing prediction from first n trees\n')
|
||||
labels <- getinfo(dtest, 'label')
|
||||
|
||||
### predict using first 1 tree
|
||||
ypred1 <- predict(bst, dtest, ntreelimit = 1)
|
||||
ypred1 <- predict(bst, dtest, iterationrange = c(1, 1))
|
||||
# by default, we predict using all the trees
|
||||
ypred2 <- predict(bst, dtest)
|
||||
|
||||
|
||||
@ -35,8 +35,6 @@ Callback function expects the following values to be set in its calling frame:
|
||||
\code{data},
|
||||
\code{end_iteration},
|
||||
\code{params},
|
||||
\code{num_parallel_tree},
|
||||
\code{num_class}.
|
||||
}
|
||||
\seealso{
|
||||
\code{\link{callbacks}}
|
||||
|
||||
@ -55,7 +55,6 @@ Callback function expects the following values to be set in its calling frame:
|
||||
\code{iteration},
|
||||
\code{begin_iteration},
|
||||
\code{end_iteration},
|
||||
\code{num_parallel_tree}.
|
||||
}
|
||||
\seealso{
|
||||
\code{\link{callbacks}},
|
||||
|
||||
@ -9,7 +9,6 @@
|
||||
newdata,
|
||||
missing = NA,
|
||||
outputmargin = FALSE,
|
||||
ntreelimit = NULL,
|
||||
predleaf = FALSE,
|
||||
predcontrib = FALSE,
|
||||
approxcontrib = FALSE,
|
||||
@ -36,8 +35,6 @@ missing values in data (e.g., 0 or some other extreme value).}
|
||||
sum of predictions from boosting iterations' results. E.g., setting \code{outputmargin=TRUE} for
|
||||
logistic regression would return log-odds instead of probabilities.}
|
||||
|
||||
\item{ntreelimit}{Deprecated, use \code{iterationrange} instead.}
|
||||
|
||||
\item{predleaf}{Whether to predict pre-tree leaf indices.}
|
||||
|
||||
\item{predcontrib}{Whether to return feature contributions to individual predictions (see Details).}
|
||||
@ -53,11 +50,18 @@ or \code{predinteraction} is \code{TRUE}.}
|
||||
\item{training}{Whether the predictions are used for training. For dart booster,
|
||||
training predicting will perform dropout.}
|
||||
|
||||
\item{iterationrange}{Specifies which trees are used in prediction. For
|
||||
example, take a random forest with 100 rounds.
|
||||
With \code{iterationrange=c(1, 21)}, only the trees built during \verb{[1, 21)} (half open set)
|
||||
rounds are used in this prediction. The index is 1-based just like an R vector. When set
|
||||
to \code{c(1, 1)}, XGBoost will use all trees.}
|
||||
\item{iterationrange}{Sequence of rounds/iterations from the model to use for prediction, specified by passing
|
||||
a two-dimensional vector with the start and end numbers in the sequence (same format as R's \code{seq} - i.e.
|
||||
base-1 indexing, and inclusive of both ends).
|
||||
|
||||
\if{html}{\out{<div class="sourceCode">}}\preformatted{ For example, passing `c(1,20)` will predict using the first twenty iterations, while passing `c(1,1)` will
|
||||
predict using only the first one.
|
||||
|
||||
If passing `NULL`, will either stop at the best iteration if the model used early stopping, or use all
|
||||
of the iterations (rounds) otherwise.
|
||||
|
||||
If passing "all", will use all of the rounds regardless of whether the model had early stopping or not.
|
||||
}\if{html}{\out{</div>}}}
|
||||
|
||||
\item{strict_shape}{Default is \code{FALSE}. When set to \code{TRUE}, the output
|
||||
type and shape of predictions are invariant to the model type.}
|
||||
@ -145,7 +149,7 @@ bst <- xgb.train(
|
||||
# use all trees by default
|
||||
pred <- predict(bst, test$data)
|
||||
# use only the 1st tree
|
||||
pred1 <- predict(bst, test$data, iterationrange = c(1, 2))
|
||||
pred1 <- predict(bst, test$data, iterationrange = c(1, 1))
|
||||
|
||||
# Predicting tree leafs:
|
||||
# the result is an nsamples X ntrees matrix
|
||||
@ -216,7 +220,7 @@ str(pred)
|
||||
all.equal(pred, pred_labels)
|
||||
# prediction from using only 5 iterations should result
|
||||
# in the same error as seen in iteration 5:
|
||||
pred5 <- predict(bst, as.matrix(iris[, -5]), iterationrange = c(1, 6))
|
||||
pred5 <- predict(bst, as.matrix(iris[, -5]), iterationrange = c(1, 5))
|
||||
sum(pred5 != lb) / length(lb)
|
||||
|
||||
}
|
||||
|
||||
@ -135,7 +135,6 @@ It is created by the \code{\link{cb.evaluation.log}} callback.
|
||||
parameter or randomly generated.
|
||||
\item \code{best_iteration} iteration number with the best evaluation metric value
|
||||
(only available with early stopping).
|
||||
\item \code{best_ntreelimit} and the \code{ntreelimit} Deprecated attributes, use \code{best_iteration} instead.
|
||||
\item \code{pred} CV prediction values available when \code{prediction} is set.
|
||||
It is either vector or matrix (see \code{\link{cb.cv.predict}}).
|
||||
\item \code{models} a list of the CV folds' models. It is only available with the explicit
|
||||
|
||||
@ -33,15 +33,11 @@ test_that("train and predict binary classification", {
|
||||
pred <- predict(bst, test$data)
|
||||
expect_length(pred, 1611)
|
||||
|
||||
pred1 <- predict(bst, train$data, ntreelimit = 1)
|
||||
pred1 <- predict(bst, train$data, iterationrange = c(1, 1))
|
||||
expect_length(pred1, 6513)
|
||||
err_pred1 <- sum((pred1 > 0.5) != train$label) / length(train$label)
|
||||
err_log <- attributes(bst)$evaluation_log[1, train_error]
|
||||
expect_lt(abs(err_pred1 - err_log), 10e-6)
|
||||
|
||||
pred2 <- predict(bst, train$data, iterationrange = c(1, 2))
|
||||
expect_length(pred1, 6513)
|
||||
expect_equal(pred1, pred2)
|
||||
})
|
||||
|
||||
test_that("parameter validation works", {
|
||||
@ -117,8 +113,8 @@ test_that("dart prediction works", {
|
||||
nrounds = nrounds,
|
||||
objective = "reg:squarederror"
|
||||
)
|
||||
pred_by_xgboost_0 <- predict(booster_by_xgboost, newdata = d, ntreelimit = 0)
|
||||
pred_by_xgboost_1 <- predict(booster_by_xgboost, newdata = d, ntreelimit = nrounds)
|
||||
pred_by_xgboost_0 <- predict(booster_by_xgboost, newdata = d, iterationrange = NULL)
|
||||
pred_by_xgboost_1 <- predict(booster_by_xgboost, newdata = d, iterationrange = c(1, nrounds))
|
||||
expect_true(all(matrix(pred_by_xgboost_0, byrow = TRUE) == matrix(pred_by_xgboost_1, byrow = TRUE)))
|
||||
|
||||
pred_by_xgboost_2 <- predict(booster_by_xgboost, newdata = d, training = TRUE)
|
||||
@ -139,8 +135,8 @@ test_that("dart prediction works", {
|
||||
data = dtrain,
|
||||
nrounds = nrounds
|
||||
)
|
||||
pred_by_train_0 <- predict(booster_by_train, newdata = dtrain, ntreelimit = 0)
|
||||
pred_by_train_1 <- predict(booster_by_train, newdata = dtrain, ntreelimit = nrounds)
|
||||
pred_by_train_0 <- predict(booster_by_train, newdata = dtrain, iterationrange = NULL)
|
||||
pred_by_train_1 <- predict(booster_by_train, newdata = dtrain, iterationrange = c(1, nrounds))
|
||||
pred_by_train_2 <- predict(booster_by_train, newdata = dtrain, training = TRUE)
|
||||
|
||||
expect_true(all(matrix(pred_by_train_0, byrow = TRUE) == matrix(pred_by_xgboost_0, byrow = TRUE)))
|
||||
@ -162,7 +158,7 @@ test_that("train and predict softprob", {
|
||||
)
|
||||
expect_false(is.null(attributes(bst)$evaluation_log))
|
||||
expect_lt(attributes(bst)$evaluation_log[, min(train_merror)], 0.025)
|
||||
expect_equal(xgb.get.num.boosted.rounds(bst) * 3, xgb.ntree(bst))
|
||||
expect_equal(xgb.get.num.boosted.rounds(bst), 5)
|
||||
pred <- predict(bst, as.matrix(iris[, -5]))
|
||||
expect_length(pred, nrow(iris) * 3)
|
||||
# row sums add up to total probability of 1:
|
||||
@ -174,12 +170,12 @@ test_that("train and predict softprob", {
|
||||
err <- sum(pred_labels != lb) / length(lb)
|
||||
expect_equal(attributes(bst)$evaluation_log[5, train_merror], err, tolerance = 5e-6)
|
||||
# manually calculate error at the 1st iteration:
|
||||
mpred <- predict(bst, as.matrix(iris[, -5]), reshape = TRUE, ntreelimit = 1)
|
||||
mpred <- predict(bst, as.matrix(iris[, -5]), reshape = TRUE, iterationrange = c(1, 1))
|
||||
pred_labels <- max.col(mpred) - 1
|
||||
err <- sum(pred_labels != lb) / length(lb)
|
||||
expect_equal(attributes(bst)$evaluation_log[1, train_merror], err, tolerance = 5e-6)
|
||||
|
||||
mpred1 <- predict(bst, as.matrix(iris[, -5]), reshape = TRUE, iterationrange = c(1, 2))
|
||||
mpred1 <- predict(bst, as.matrix(iris[, -5]), reshape = TRUE, iterationrange = c(1, 1))
|
||||
expect_equal(mpred, mpred1)
|
||||
|
||||
d <- cbind(
|
||||
@ -213,7 +209,7 @@ test_that("train and predict softmax", {
|
||||
)
|
||||
expect_false(is.null(attributes(bst)$evaluation_log))
|
||||
expect_lt(attributes(bst)$evaluation_log[, min(train_merror)], 0.025)
|
||||
expect_equal(xgb.get.num.boosted.rounds(bst) * 3, xgb.ntree(bst))
|
||||
expect_equal(xgb.get.num.boosted.rounds(bst), 5)
|
||||
|
||||
pred <- predict(bst, as.matrix(iris[, -5]))
|
||||
expect_length(pred, nrow(iris))
|
||||
@ -233,19 +229,15 @@ test_that("train and predict RF", {
|
||||
watchlist = list(train = xgb.DMatrix(train$data, label = lb))
|
||||
)
|
||||
expect_equal(xgb.get.num.boosted.rounds(bst), 1)
|
||||
expect_equal(xgb.ntree(bst), 20)
|
||||
|
||||
pred <- predict(bst, train$data)
|
||||
pred_err <- sum((pred > 0.5) != lb) / length(lb)
|
||||
expect_lt(abs(attributes(bst)$evaluation_log[1, train_error] - pred_err), 10e-6)
|
||||
# expect_lt(pred_err, 0.03)
|
||||
|
||||
pred <- predict(bst, train$data, ntreelimit = 20)
|
||||
pred <- predict(bst, train$data, iterationrange = c(1, 1))
|
||||
pred_err_20 <- sum((pred > 0.5) != lb) / length(lb)
|
||||
expect_equal(pred_err_20, pred_err)
|
||||
|
||||
pred1 <- predict(bst, train$data, iterationrange = c(1, 2))
|
||||
expect_equal(pred, pred1)
|
||||
})
|
||||
|
||||
test_that("train and predict RF with softprob", {
|
||||
@ -261,7 +253,6 @@ test_that("train and predict RF with softprob", {
|
||||
watchlist = list(train = xgb.DMatrix(as.matrix(iris[, -5]), label = lb))
|
||||
)
|
||||
expect_equal(xgb.get.num.boosted.rounds(bst), 15)
|
||||
expect_equal(xgb.ntree(bst), 15 * 3 * 4)
|
||||
# predict for all iterations:
|
||||
pred <- predict(bst, as.matrix(iris[, -5]), reshape = TRUE)
|
||||
expect_equal(dim(pred), c(nrow(iris), 3))
|
||||
@ -269,7 +260,7 @@ test_that("train and predict RF with softprob", {
|
||||
err <- sum(pred_labels != lb) / length(lb)
|
||||
expect_equal(attributes(bst)$evaluation_log[nrounds, train_merror], err, tolerance = 5e-6)
|
||||
# predict for 7 iterations and adjust for 4 parallel trees per iteration
|
||||
pred <- predict(bst, as.matrix(iris[, -5]), reshape = TRUE, ntreelimit = 7 * 4)
|
||||
pred <- predict(bst, as.matrix(iris[, -5]), reshape = TRUE, iterationrange = c(1, 7))
|
||||
err <- sum((max.col(pred) - 1) != lb) / length(lb)
|
||||
expect_equal(attributes(bst)$evaluation_log[7, train_merror], err, tolerance = 5e-6)
|
||||
})
|
||||
|
||||
@ -211,12 +211,11 @@ test_that("early stopping xgb.train works", {
|
||||
, "Stopping. Best iteration")
|
||||
expect_false(is.null(xgb.attr(bst, "best_iteration")))
|
||||
expect_lt(xgb.attr(bst, "best_iteration"), 19)
|
||||
expect_equal(xgb.attr(bst, "best_iteration"), xgb.attr(bst, "best_ntreelimit"))
|
||||
|
||||
pred <- predict(bst, dtest)
|
||||
expect_equal(length(pred), 1611)
|
||||
err_pred <- err(ltest, pred)
|
||||
err_log <- attributes(bst)$evaluation_log[xgb.attr(bst, "best_iteration"), test_error]
|
||||
err_log <- attributes(bst)$evaluation_log[xgb.attr(bst, "best_iteration") + 1, test_error]
|
||||
expect_equal(err_log, err_pred, tolerance = 5e-6)
|
||||
|
||||
set.seed(11)
|
||||
@ -231,8 +230,7 @@ test_that("early stopping xgb.train works", {
|
||||
loaded <- xgb.load(fname)
|
||||
|
||||
expect_false(is.null(xgb.attr(loaded, "best_iteration")))
|
||||
expect_equal(xgb.attr(loaded, "best_iteration"), xgb.attr(bst, "best_ntreelimit"))
|
||||
expect_equal(xgb.attr(loaded, "best_ntreelimit"), xgb.attr(bst, "best_ntreelimit"))
|
||||
expect_equal(xgb.attr(loaded, "best_iteration"), xgb.attr(bst, "best_iteration"))
|
||||
})
|
||||
|
||||
test_that("early stopping using a specific metric works", {
|
||||
@ -245,12 +243,11 @@ test_that("early stopping using a specific metric works", {
|
||||
, "Stopping. Best iteration")
|
||||
expect_false(is.null(xgb.attr(bst, "best_iteration")))
|
||||
expect_lt(xgb.attr(bst, "best_iteration"), 19)
|
||||
expect_equal(xgb.attr(bst, "best_iteration"), xgb.attr(bst, "best_ntreelimit"))
|
||||
|
||||
pred <- predict(bst, dtest, ntreelimit = xgb.attr(bst, "best_ntreelimit"))
|
||||
pred <- predict(bst, dtest, iterationrange = c(1, xgb.attr(bst, "best_iteration") + 1))
|
||||
expect_equal(length(pred), 1611)
|
||||
logloss_pred <- sum(-ltest * log(pred) - (1 - ltest) * log(1 - pred)) / length(ltest)
|
||||
logloss_log <- attributes(bst)$evaluation_log[xgb.attr(bst, "best_iteration"), test_logloss]
|
||||
logloss_log <- attributes(bst)$evaluation_log[xgb.attr(bst, "best_iteration") + 1, test_logloss]
|
||||
expect_equal(logloss_log, logloss_pred, tolerance = 1e-5)
|
||||
})
|
||||
|
||||
@ -286,7 +283,6 @@ test_that("early stopping xgb.cv works", {
|
||||
, "Stopping. Best iteration")
|
||||
expect_false(is.null(cv$best_iteration))
|
||||
expect_lt(cv$best_iteration, 19)
|
||||
expect_equal(cv$best_iteration, cv$best_ntreelimit)
|
||||
# the best error is min error:
|
||||
expect_true(cv$evaluation_log[, test_error_mean[cv$best_iteration] == min(test_error_mean)])
|
||||
})
|
||||
@ -354,3 +350,44 @@ test_that("prediction in xgb.cv for softprob works", {
|
||||
expect_equal(dim(cv$pred), c(nrow(iris), 3))
|
||||
expect_lt(diff(range(rowSums(cv$pred))), 1e-6)
|
||||
})
|
||||
|
||||
test_that("prediction in xgb.cv works for multi-quantile", {
|
||||
data(mtcars)
|
||||
y <- mtcars$mpg
|
||||
x <- as.matrix(mtcars[, -1])
|
||||
dm <- xgb.DMatrix(x, label = y, nthread = 1)
|
||||
cv <- xgb.cv(
|
||||
data = dm,
|
||||
params = list(
|
||||
objective = "reg:quantileerror",
|
||||
quantile_alpha = c(0.1, 0.2, 0.5, 0.8, 0.9),
|
||||
nthread = 1
|
||||
),
|
||||
nrounds = 5,
|
||||
nfold = 3,
|
||||
prediction = TRUE,
|
||||
verbose = 0
|
||||
)
|
||||
expect_equal(dim(cv$pred), c(nrow(x), 5))
|
||||
})
|
||||
|
||||
test_that("prediction in xgb.cv works for multi-output", {
|
||||
data(mtcars)
|
||||
y <- mtcars$mpg
|
||||
x <- as.matrix(mtcars[, -1])
|
||||
dm <- xgb.DMatrix(x, label = cbind(y, -y), nthread = 1)
|
||||
cv <- xgb.cv(
|
||||
data = dm,
|
||||
params = list(
|
||||
tree_method = "hist",
|
||||
multi_strategy = "multi_output_tree",
|
||||
objective = "reg:squarederror",
|
||||
nthread = n_threads
|
||||
),
|
||||
nrounds = 5,
|
||||
nfold = 3,
|
||||
prediction = TRUE,
|
||||
verbose = 0
|
||||
)
|
||||
expect_equal(dim(cv$pred), c(nrow(x), 2))
|
||||
})
|
||||
|
||||
@ -72,10 +72,10 @@ test_that("gblinear early stopping works", {
|
||||
booster <- xgb.train(
|
||||
param, dtrain, n, list(eval = dtest, train = dtrain), early_stopping_rounds = es_round
|
||||
)
|
||||
expect_equal(xgb.attr(booster, "best_iteration"), 5)
|
||||
expect_equal(xgb.attr(booster, "best_iteration"), 4)
|
||||
predt_es <- predict(booster, dtrain)
|
||||
|
||||
n <- xgb.attr(booster, "best_iteration") + es_round
|
||||
n <- xgb.attr(booster, "best_iteration") + es_round + 1
|
||||
booster <- xgb.train(
|
||||
param, dtrain, n, list(eval = dtest, train = dtrain), early_stopping_rounds = es_round
|
||||
)
|
||||
|
||||
@ -44,7 +44,7 @@ test_that('Test ranking with weighted data', {
|
||||
expect_true(all(diff(attributes(bst)$evaluation_log$train_auc) >= 0))
|
||||
expect_true(all(diff(attributes(bst)$evaluation_log$train_aucpr) >= 0))
|
||||
for (i in 1:10) {
|
||||
pred <- predict(bst, newdata = dtrain, ntreelimit = i)
|
||||
pred <- predict(bst, newdata = dtrain, iterationrange = c(1, i))
|
||||
# is_sorted[i]: is i-th group correctly sorted by the ranking predictor?
|
||||
is_sorted <- lapply(seq(1, 20, by = 5),
|
||||
function(k) {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user