[R] enable multi-dimensional base_margin (#9885)
This commit is contained in:
@@ -349,3 +349,29 @@ test_that("xgb.DMatrix: data.frame", {
|
||||
m <- xgb.DMatrix(df, enable_categorical = TRUE)
|
||||
expect_equal(getinfo(m, "feature_type"), c("c", "c"))
|
||||
})
|
||||
|
||||
test_that("xgb.DMatrix: can take multi-dimensional 'base_margin'", {
|
||||
set.seed(123)
|
||||
x <- matrix(rnorm(100 * 10), nrow = 100)
|
||||
y <- matrix(rnorm(100 * 2), nrow = 100)
|
||||
b <- matrix(rnorm(100 * 2), nrow = 100)
|
||||
model <- xgb.train(
|
||||
data = xgb.DMatrix(data = x, label = y, nthread = n_threads),
|
||||
params = list(
|
||||
objective = "reg:squarederror",
|
||||
tree_method = "hist",
|
||||
multi_strategy = "multi_output_tree",
|
||||
base_score = 0,
|
||||
nthread = n_threads
|
||||
),
|
||||
nround = 1
|
||||
)
|
||||
pred_only_x <- predict(model, x, nthread = n_threads, reshape = TRUE)
|
||||
pred_w_base <- predict(
|
||||
model,
|
||||
xgb.DMatrix(data = x, base_margin = b, nthread = n_threads),
|
||||
nthread = n_threads,
|
||||
reshape = TRUE
|
||||
)
|
||||
expect_equal(pred_only_x, pred_w_base - b, tolerance = 1e-5)
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user