[R] Support booster slicing. (#9948)
This commit is contained in:
@@ -693,16 +693,94 @@ setinfo.xgb.Booster <- function(object, name, info) {
|
||||
}
|
||||
|
||||
#' @title Get number of boosting in a fitted booster
|
||||
#' @param model A fitted `xgb.Booster` model.
|
||||
#' @param model,x A fitted `xgb.Booster` model.
|
||||
#' @return The number of rounds saved in the model, as an integer.
|
||||
#' @details Note that setting booster parameters related to training
|
||||
#' continuation / updates through \link{xgb.parameters<-} will reset the
|
||||
#' number of rounds to zero.
|
||||
#' @export
|
||||
#' @rdname xgb.get.num.boosted.rounds
|
||||
xgb.get.num.boosted.rounds <- function(model) {
|
||||
return(.Call(XGBoosterBoostedRounds_R, xgb.get.handle(model)))
|
||||
}
|
||||
|
||||
#' @rdname xgb.get.num.boosted.rounds
|
||||
#' @export
|
||||
length.xgb.Booster <- function(x) {
|
||||
return(xgb.get.num.boosted.rounds(x))
|
||||
}
|
||||
|
||||
#' @title Slice Booster by Rounds
|
||||
#' @description Creates a new booster including only a selected range of rounds / iterations
|
||||
#' from an existing booster, as given by the sequence `seq(start, end, step)`.
|
||||
#' @details Note that any R attributes that the booster might have, will not be copied into
|
||||
#' the resulting object.
|
||||
#' @param model,x A fitted `xgb.Booster` object, which is to be sliced by taking only a subset
|
||||
#' of its rounds / iterations.
|
||||
#' @param start Start of the slice (base-1 and inclusive, like R's \link{seq}).
|
||||
#' @param end End of the slice (base-1 and inclusive, like R's \link{seq}).
|
||||
#'
|
||||
#' Passing a value of zero here is equivalent to passing the full number of rounds in the
|
||||
#' booster object.
|
||||
#' @param step Step size of the slice. Passing '1' will take every round in the sequence defined by
|
||||
#' `(start, end)`, while passing '2' will take every second value, and so on.
|
||||
#' @return A sliced booster object containing only the requested rounds.
|
||||
#' @examples
|
||||
#' data(mtcars)
|
||||
#' y <- mtcars$mpg
|
||||
#' x <- as.matrix(mtcars[, -1])
|
||||
#' dm <- xgb.DMatrix(x, label = y, nthread = 1)
|
||||
#' model <- xgb.train(data = dm, params = list(nthread = 1), nrounds = 5)
|
||||
#' model_slice <- xgb.slice.Booster(model, 1, 3)
|
||||
#' # Prediction for first three rounds
|
||||
#' predict(model, x, predleaf = TRUE)[, 1:3]
|
||||
#'
|
||||
#' # The new model has only those rounds, so
|
||||
#' # a full prediction from it is equivalent
|
||||
#' predict(model_slice, x, predleaf = TRUE)
|
||||
#' @export
|
||||
#' @rdname xgb.slice.Booster
|
||||
xgb.slice.Booster <- function(model, start, end = xgb.get.num.boosted.rounds(model), step = 1L) {
|
||||
# This makes the slice mimic the behavior of R's 'seq',
|
||||
# which truncates on the end of the slice when the step
|
||||
# doesn't reach it.
|
||||
if (end > start && step > 1) {
|
||||
d <- (end - start + 1) / step
|
||||
if (d != floor(d)) {
|
||||
end <- start + step * ceiling(d) - 1
|
||||
}
|
||||
}
|
||||
return(
|
||||
.Call(
|
||||
XGBoosterSlice_R,
|
||||
xgb.get.handle(model),
|
||||
start - 1,
|
||||
end,
|
||||
step
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
#' @export
|
||||
#' @rdname xgb.slice.Booster
|
||||
#' @param i The indices - must be an increasing sequence as generated by e.g. `seq(...)`.
|
||||
`[.xgb.Booster` <- function(x, i) {
|
||||
if (missing(i)) {
|
||||
return(xgb.slice.Booster(x, 1, 0))
|
||||
}
|
||||
if (length(i) == 1) {
|
||||
return(xgb.slice.Booster(x, i, i))
|
||||
}
|
||||
steps <- diff(i)
|
||||
if (any(steps < 0)) {
|
||||
stop("Can only slice booster with ascending sequences.")
|
||||
}
|
||||
if (length(unique(steps)) > 1) {
|
||||
stop("Can only slice booster with fixed-step sequences.")
|
||||
}
|
||||
return(xgb.slice.Booster(x, i[1L], i[length(i)], steps[1L]))
|
||||
}
|
||||
|
||||
#' @title Get Features Names from Booster
|
||||
#' @description Returns the feature / variable / column names from a fitted
|
||||
#' booster object, which are set automatically during the call to \link{xgb.train}
|
||||
|
||||
Reference in New Issue
Block a user