[R] enable multi-dimensional base_margin (#9885)
This commit is contained in:
parent
936b22fdf3
commit
cd473c9da3
@ -16,6 +16,8 @@
|
|||||||
#' only care about the relative ordering of data points within each group,
|
#' only care about the relative ordering of data points within each group,
|
||||||
#' so it doesn't make sense to assign weights to individual data points.
|
#' so it doesn't make sense to assign weights to individual data points.
|
||||||
#' @param base_margin Base margin used for boosting from existing model.
|
#' @param base_margin Base margin used for boosting from existing model.
|
||||||
|
#'
|
||||||
|
#' In the case of multi-output models, one can also pass multi-dimensional base_margin.
|
||||||
#' @param missing a float value to represents missing values in data (used only when input is a dense matrix).
|
#' @param missing a float value to represents missing values in data (used only when input is a dense matrix).
|
||||||
#' It is useful when a 0 or some other extreme value represents missing values in data.
|
#' It is useful when a 0 or some other extreme value represents missing values in data.
|
||||||
#' @param silent whether to suppress printing an informational message after loading from a file.
|
#' @param silent whether to suppress printing an informational message after loading from a file.
|
||||||
@ -439,9 +441,7 @@ setinfo.xgb.DMatrix <- function(object, name, info, ...) {
|
|||||||
return(TRUE)
|
return(TRUE)
|
||||||
}
|
}
|
||||||
if (name == "base_margin") {
|
if (name == "base_margin") {
|
||||||
# if (length(info)!=nrow(object))
|
.Call(XGDMatrixSetInfo_R, object, name, info)
|
||||||
# stop("The length of base margin must equal to the number of rows in the input data")
|
|
||||||
.Call(XGDMatrixSetInfo_R, object, name, as.numeric(info))
|
|
||||||
return(TRUE)
|
return(TRUE)
|
||||||
}
|
}
|
||||||
if (name == "group") {
|
if (name == "group") {
|
||||||
|
|||||||
@ -36,7 +36,9 @@ is assigned to each group (not each data point). This is because we
|
|||||||
only care about the relative ordering of data points within each group,
|
only care about the relative ordering of data points within each group,
|
||||||
so it doesn't make sense to assign weights to individual data points.}
|
so it doesn't make sense to assign weights to individual data points.}
|
||||||
|
|
||||||
\item{base_margin}{Base margin used for boosting from existing model.}
|
\item{base_margin}{Base margin used for boosting from existing model.
|
||||||
|
|
||||||
|
In the case of multi-output models, one can also pass multi-dimensional base_margin.}
|
||||||
|
|
||||||
\item{missing}{a float value to represents missing values in data (used only when input is a dense matrix).
|
\item{missing}{a float value to represents missing values in data (used only when input is a dense matrix).
|
||||||
It is useful when a 0 or some other extreme value represents missing values in data.}
|
It is useful when a 0 or some other extreme value represents missing values in data.}
|
||||||
|
|||||||
@ -349,3 +349,29 @@ test_that("xgb.DMatrix: data.frame", {
|
|||||||
m <- xgb.DMatrix(df, enable_categorical = TRUE)
|
m <- xgb.DMatrix(df, enable_categorical = TRUE)
|
||||||
expect_equal(getinfo(m, "feature_type"), c("c", "c"))
|
expect_equal(getinfo(m, "feature_type"), c("c", "c"))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
test_that("xgb.DMatrix: can take multi-dimensional 'base_margin'", {
|
||||||
|
set.seed(123)
|
||||||
|
x <- matrix(rnorm(100 * 10), nrow = 100)
|
||||||
|
y <- matrix(rnorm(100 * 2), nrow = 100)
|
||||||
|
b <- matrix(rnorm(100 * 2), nrow = 100)
|
||||||
|
model <- xgb.train(
|
||||||
|
data = xgb.DMatrix(data = x, label = y, nthread = n_threads),
|
||||||
|
params = list(
|
||||||
|
objective = "reg:squarederror",
|
||||||
|
tree_method = "hist",
|
||||||
|
multi_strategy = "multi_output_tree",
|
||||||
|
base_score = 0,
|
||||||
|
nthread = n_threads
|
||||||
|
),
|
||||||
|
nround = 1
|
||||||
|
)
|
||||||
|
pred_only_x <- predict(model, x, nthread = n_threads, reshape = TRUE)
|
||||||
|
pred_w_base <- predict(
|
||||||
|
model,
|
||||||
|
xgb.DMatrix(data = x, base_margin = b, nthread = n_threads),
|
||||||
|
nthread = n_threads,
|
||||||
|
reshape = TRUE
|
||||||
|
)
|
||||||
|
expect_equal(pred_only_x, pred_w_base - b, tolerance = 1e-5)
|
||||||
|
})
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user