[R] enable multi-dimensional base_margin (#9885)

This commit is contained in:
david-cortes
2023-12-14 02:16:53 +01:00
committed by GitHub
parent 936b22fdf3
commit cd473c9da3
3 changed files with 32 additions and 4 deletions

View File

@@ -349,3 +349,29 @@ test_that("xgb.DMatrix: data.frame", {
m <- xgb.DMatrix(df, enable_categorical = TRUE)
expect_equal(getinfo(m, "feature_type"), c("c", "c"))
})
test_that("xgb.DMatrix: can take multi-dimensional 'base_margin'", {
set.seed(123)
x <- matrix(rnorm(100 * 10), nrow = 100)
y <- matrix(rnorm(100 * 2), nrow = 100)
b <- matrix(rnorm(100 * 2), nrow = 100)
model <- xgb.train(
data = xgb.DMatrix(data = x, label = y, nthread = n_threads),
params = list(
objective = "reg:squarederror",
tree_method = "hist",
multi_strategy = "multi_output_tree",
base_score = 0,
nthread = n_threads
),
nround = 1
)
pred_only_x <- predict(model, x, nthread = n_threads, reshape = TRUE)
pred_w_base <- predict(
model,
xgb.DMatrix(data = x, base_margin = b, nthread = n_threads),
nthread = n_threads,
reshape = TRUE
)
expect_equal(pred_only_x, pred_w_base - b, tolerance = 1e-5)
})