[R] Support booster slicing. (#9948)

This commit is contained in:
david-cortes
2024-01-20 22:11:26 +01:00
committed by GitHub
parent c5d0608057
commit 5062a3ab46
8 changed files with 237 additions and 2 deletions

View File

@@ -693,16 +693,94 @@ setinfo.xgb.Booster <- function(object, name, info) {
}
#' @title Get number of boosting in a fitted booster
#' @param model A fitted `xgb.Booster` model.
#' @param model,x A fitted `xgb.Booster` model.
#' @return The number of rounds saved in the model, as an integer.
#' @details Note that setting booster parameters related to training
#' continuation / updates through \link{xgb.parameters<-} will reset the
#' number of rounds to zero.
#' @export
#' @rdname xgb.get.num.boosted.rounds
xgb.get.num.boosted.rounds <- function(model) {
return(.Call(XGBoosterBoostedRounds_R, xgb.get.handle(model)))
}
#' @rdname xgb.get.num.boosted.rounds
#' @export
length.xgb.Booster <- function(x) {
return(xgb.get.num.boosted.rounds(x))
}
#' @title Slice Booster by Rounds
#' @description Creates a new booster including only a selected range of rounds / iterations
#' from an existing booster, as given by the sequence `seq(start, end, step)`.
#' @details Note that any R attributes that the booster might have, will not be copied into
#' the resulting object.
#' @param model,x A fitted `xgb.Booster` object, which is to be sliced by taking only a subset
#' of its rounds / iterations.
#' @param start Start of the slice (base-1 and inclusive, like R's \link{seq}).
#' @param end End of the slice (base-1 and inclusive, like R's \link{seq}).
#'
#' Passing a value of zero here is equivalent to passing the full number of rounds in the
#' booster object.
#' @param step Step size of the slice. Passing '1' will take every round in the sequence defined by
#' `(start, end)`, while passing '2' will take every second value, and so on.
#' @return A sliced booster object containing only the requested rounds.
#' @examples
#' data(mtcars)
#' y <- mtcars$mpg
#' x <- as.matrix(mtcars[, -1])
#' dm <- xgb.DMatrix(x, label = y, nthread = 1)
#' model <- xgb.train(data = dm, params = list(nthread = 1), nrounds = 5)
#' model_slice <- xgb.slice.Booster(model, 1, 3)
#' # Prediction for first three rounds
#' predict(model, x, predleaf = TRUE)[, 1:3]
#'
#' # The new model has only those rounds, so
#' # a full prediction from it is equivalent
#' predict(model_slice, x, predleaf = TRUE)
#' @export
#' @rdname xgb.slice.Booster
xgb.slice.Booster <- function(model, start, end = xgb.get.num.boosted.rounds(model), step = 1L) {
# This makes the slice mimic the behavior of R's 'seq',
# which truncates on the end of the slice when the step
# doesn't reach it.
if (end > start && step > 1) {
d <- (end - start + 1) / step
if (d != floor(d)) {
end <- start + step * ceiling(d) - 1
}
}
return(
.Call(
XGBoosterSlice_R,
xgb.get.handle(model),
start - 1,
end,
step
)
)
}
#' @export
#' @rdname xgb.slice.Booster
#' @param i The indices - must be an increasing sequence as generated by e.g. `seq(...)`.
`[.xgb.Booster` <- function(x, i) {
if (missing(i)) {
return(xgb.slice.Booster(x, 1, 0))
}
if (length(i) == 1) {
return(xgb.slice.Booster(x, i, i))
}
steps <- diff(i)
if (any(steps < 0)) {
stop("Can only slice booster with ascending sequences.")
}
if (length(unique(steps)) > 1) {
stop("Can only slice booster with fixed-step sequences.")
}
return(xgb.slice.Booster(x, i[1L], i[length(i)], steps[1L]))
}
#' @title Get Features Names from Booster
#' @description Returns the feature / variable / column names from a fitted
#' booster object, which are set automatically during the call to \link{xgb.train}