[R] enable multi-dimensional base_margin (#9885)
This commit is contained in:
parent
936b22fdf3
commit
cd473c9da3
@ -16,6 +16,8 @@
|
||||
#' only care about the relative ordering of data points within each group,
|
||||
#' so it doesn't make sense to assign weights to individual data points.
|
||||
#' @param base_margin Base margin used for boosting from existing model.
|
||||
#'
|
||||
#' In the case of multi-output models, one can also pass multi-dimensional base_margin.
|
||||
#' @param missing a float value to represents missing values in data (used only when input is a dense matrix).
|
||||
#' It is useful when a 0 or some other extreme value represents missing values in data.
|
||||
#' @param silent whether to suppress printing an informational message after loading from a file.
|
||||
@ -439,9 +441,7 @@ setinfo.xgb.DMatrix <- function(object, name, info, ...) {
|
||||
return(TRUE)
|
||||
}
|
||||
if (name == "base_margin") {
|
||||
# if (length(info)!=nrow(object))
|
||||
# stop("The length of base margin must equal to the number of rows in the input data")
|
||||
.Call(XGDMatrixSetInfo_R, object, name, as.numeric(info))
|
||||
.Call(XGDMatrixSetInfo_R, object, name, info)
|
||||
return(TRUE)
|
||||
}
|
||||
if (name == "group") {
|
||||
|
||||
@ -36,7 +36,9 @@ is assigned to each group (not each data point). This is because we
|
||||
only care about the relative ordering of data points within each group,
|
||||
so it doesn't make sense to assign weights to individual data points.}
|
||||
|
||||
\item{base_margin}{Base margin used for boosting from existing model.}
|
||||
\item{base_margin}{Base margin used for boosting from existing model.
|
||||
|
||||
In the case of multi-output models, one can also pass multi-dimensional base_margin.}
|
||||
|
||||
\item{missing}{a float value to represents missing values in data (used only when input is a dense matrix).
|
||||
It is useful when a 0 or some other extreme value represents missing values in data.}
|
||||
|
||||
@ -349,3 +349,29 @@ test_that("xgb.DMatrix: data.frame", {
|
||||
m <- xgb.DMatrix(df, enable_categorical = TRUE)
|
||||
expect_equal(getinfo(m, "feature_type"), c("c", "c"))
|
||||
})
|
||||
|
||||
test_that("xgb.DMatrix: can take multi-dimensional 'base_margin'", {
|
||||
set.seed(123)
|
||||
x <- matrix(rnorm(100 * 10), nrow = 100)
|
||||
y <- matrix(rnorm(100 * 2), nrow = 100)
|
||||
b <- matrix(rnorm(100 * 2), nrow = 100)
|
||||
model <- xgb.train(
|
||||
data = xgb.DMatrix(data = x, label = y, nthread = n_threads),
|
||||
params = list(
|
||||
objective = "reg:squarederror",
|
||||
tree_method = "hist",
|
||||
multi_strategy = "multi_output_tree",
|
||||
base_score = 0,
|
||||
nthread = n_threads
|
||||
),
|
||||
nround = 1
|
||||
)
|
||||
pred_only_x <- predict(model, x, nthread = n_threads, reshape = TRUE)
|
||||
pred_w_base <- predict(
|
||||
model,
|
||||
xgb.DMatrix(data = x, base_margin = b, nthread = n_threads),
|
||||
nthread = n_threads,
|
||||
reshape = TRUE
|
||||
)
|
||||
expect_equal(pred_only_x, pred_w_base - b, tolerance = 1e-5)
|
||||
})
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user