[R] Enable multi-output objectives (#9839)
This commit is contained in:
@@ -565,3 +565,54 @@ test_that("'predict' accepts CSR data", {
|
||||
expect_equal(p_csc, p_csr)
|
||||
expect_equal(p_csc, p_spv)
|
||||
})
|
||||
|
||||
test_that("Can use multi-output labels with built-in objectives", {
|
||||
data("mtcars")
|
||||
y <- mtcars$mpg
|
||||
x <- as.matrix(mtcars[, -1])
|
||||
y_mirrored <- cbind(y, -y)
|
||||
dm <- xgb.DMatrix(x, label = y_mirrored, nthread = n_threads)
|
||||
model <- xgb.train(
|
||||
params = list(
|
||||
tree_method = "hist",
|
||||
multi_strategy = "multi_output_tree",
|
||||
objective = "reg:squarederror",
|
||||
nthread = n_threads
|
||||
),
|
||||
data = dm,
|
||||
nrounds = 5
|
||||
)
|
||||
pred <- predict(model, x, reshape = TRUE)
|
||||
expect_equal(pred[, 1], -pred[, 2])
|
||||
expect_true(cor(y, pred[, 1]) > 0.9)
|
||||
expect_true(cor(y, pred[, 2]) < -0.9)
|
||||
})
|
||||
|
||||
test_that("Can use multi-output labels with custom objectives", {
|
||||
data("mtcars")
|
||||
y <- mtcars$mpg
|
||||
x <- as.matrix(mtcars[, -1])
|
||||
y_mirrored <- cbind(y, -y)
|
||||
dm <- xgb.DMatrix(x, label = y_mirrored, nthread = n_threads)
|
||||
model <- xgb.train(
|
||||
params = list(
|
||||
tree_method = "hist",
|
||||
multi_strategy = "multi_output_tree",
|
||||
base_score = 0,
|
||||
objective = function(pred, dtrain) {
|
||||
y <- getinfo(dtrain, "label")
|
||||
grad <- pred - y
|
||||
hess <- rep(1, nrow(grad) * ncol(grad))
|
||||
hess <- matrix(hess, nrow = nrow(grad))
|
||||
return(list(grad = grad, hess = hess))
|
||||
},
|
||||
nthread = n_threads
|
||||
),
|
||||
data = dm,
|
||||
nrounds = 5
|
||||
)
|
||||
pred <- predict(model, x, reshape = TRUE)
|
||||
expect_equal(pred[, 1], -pred[, 2])
|
||||
expect_true(cor(y, pred[, 1]) > 0.9)
|
||||
expect_true(cor(y, pred[, 2]) < -0.9)
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user