[R] enable multi-dimensional base_margin (#9885)

This commit is contained in:
david-cortes 2023-12-14 02:16:53 +01:00 committed by GitHub
parent 936b22fdf3
commit cd473c9da3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 32 additions and 4 deletions

View File

@ -16,6 +16,8 @@
#' only care about the relative ordering of data points within each group,
#' so it doesn't make sense to assign weights to individual data points.
#' @param base_margin Base margin used for boosting from existing model.
#'
#' In the case of multi-output models, one can also pass multi-dimensional base_margin.
#' @param missing a float value to represents missing values in data (used only when input is a dense matrix).
#' It is useful when a 0 or some other extreme value represents missing values in data.
#' @param silent whether to suppress printing an informational message after loading from a file.
@ -439,9 +441,7 @@ setinfo.xgb.DMatrix <- function(object, name, info, ...) {
return(TRUE)
}
if (name == "base_margin") {
# if (length(info)!=nrow(object))
# stop("The length of base margin must equal to the number of rows in the input data")
.Call(XGDMatrixSetInfo_R, object, name, as.numeric(info))
.Call(XGDMatrixSetInfo_R, object, name, info)
return(TRUE)
}
if (name == "group") {

View File

@ -36,7 +36,9 @@ is assigned to each group (not each data point). This is because we
only care about the relative ordering of data points within each group,
so it doesn't make sense to assign weights to individual data points.}
\item{base_margin}{Base margin used for boosting from existing model.}
\item{base_margin}{Base margin used for boosting from existing model.
In the case of multi-output models, one can also pass multi-dimensional base_margin.}
\item{missing}{a float value to represents missing values in data (used only when input is a dense matrix).
It is useful when a 0 or some other extreme value represents missing values in data.}

View File

@ -349,3 +349,29 @@ test_that("xgb.DMatrix: data.frame", {
m <- xgb.DMatrix(df, enable_categorical = TRUE)
expect_equal(getinfo(m, "feature_type"), c("c", "c"))
})
test_that("xgb.DMatrix: can take multi-dimensional 'base_margin'", {
set.seed(123)
x <- matrix(rnorm(100 * 10), nrow = 100)
y <- matrix(rnorm(100 * 2), nrow = 100)
b <- matrix(rnorm(100 * 2), nrow = 100)
model <- xgb.train(
data = xgb.DMatrix(data = x, label = y, nthread = n_threads),
params = list(
objective = "reg:squarederror",
tree_method = "hist",
multi_strategy = "multi_output_tree",
base_score = 0,
nthread = n_threads
),
nround = 1
)
pred_only_x <- predict(model, x, nthread = n_threads, reshape = TRUE)
pred_w_base <- predict(
model,
xgb.DMatrix(data = x, base_margin = b, nthread = n_threads),
nthread = n_threads,
reshape = TRUE
)
expect_equal(pred_only_x, pred_w_base - b, tolerance = 1e-5)
})