[R] Enable multi-output objectives (#9839)

This commit is contained in:
david-cortes
2023-12-05 20:13:14 +01:00
committed by GitHub
parent 9c56916fd7
commit 62571b79eb
5 changed files with 78 additions and 27 deletions

View File

@@ -565,3 +565,54 @@ test_that("'predict' accepts CSR data", {
expect_equal(p_csc, p_csr)
expect_equal(p_csc, p_spv)
})
test_that("Can use multi-output labels with built-in objectives", {
data("mtcars")
y <- mtcars$mpg
x <- as.matrix(mtcars[, -1])
y_mirrored <- cbind(y, -y)
dm <- xgb.DMatrix(x, label = y_mirrored, nthread = n_threads)
model <- xgb.train(
params = list(
tree_method = "hist",
multi_strategy = "multi_output_tree",
objective = "reg:squarederror",
nthread = n_threads
),
data = dm,
nrounds = 5
)
pred <- predict(model, x, reshape = TRUE)
expect_equal(pred[, 1], -pred[, 2])
expect_true(cor(y, pred[, 1]) > 0.9)
expect_true(cor(y, pred[, 2]) < -0.9)
})
test_that("Can use multi-output labels with custom objectives", {
data("mtcars")
y <- mtcars$mpg
x <- as.matrix(mtcars[, -1])
y_mirrored <- cbind(y, -y)
dm <- xgb.DMatrix(x, label = y_mirrored, nthread = n_threads)
model <- xgb.train(
params = list(
tree_method = "hist",
multi_strategy = "multi_output_tree",
base_score = 0,
objective = function(pred, dtrain) {
y <- getinfo(dtrain, "label")
grad <- pred - y
hess <- rep(1, nrow(grad) * ncol(grad))
hess <- matrix(hess, nrow = nrow(grad))
return(list(grad = grad, hess = hess))
},
nthread = n_threads
),
data = dm,
nrounds = 5
)
pred <- predict(model, x, reshape = TRUE)
expect_equal(pred[, 1], -pred[, 2])
expect_true(cor(y, pred[, 1]) > 0.9)
expect_true(cor(y, pred[, 2]) < -0.9)
})