[R] Support multi-class custom objective. (#9526)
This commit is contained in:
@@ -64,23 +64,80 @@ test_that("custom objective using DMatrix attr works", {
|
||||
expect_equal(class(bst), "xgb.Booster")
|
||||
})
|
||||
|
||||
test_that("custom objective with multi-class works", {
|
||||
test_that("custom objective with multi-class shape", {
|
||||
data <- as.matrix(iris[, -5])
|
||||
label <- as.numeric(iris$Species) - 1
|
||||
dtrain <- xgb.DMatrix(data = data, label = label)
|
||||
nclasses <- 3
|
||||
n_classes <- 3
|
||||
|
||||
fake_softprob <- function(preds, dtrain) {
|
||||
expect_true(all(matrix(preds) == 0.5))
|
||||
grad <- rnorm(dim(as.matrix(preds))[1])
|
||||
expect_equal(dim(data)[1] * nclasses, dim(as.matrix(preds))[1])
|
||||
hess <- rnorm(dim(as.matrix(preds))[1])
|
||||
return (list(grad = grad, hess = hess))
|
||||
## use numeric vector here to test compatibility with XGBoost < 2.1
|
||||
grad <- rnorm(length(as.matrix(preds)))
|
||||
expect_equal(dim(data)[1] * n_classes, dim(as.matrix(preds))[1] * n_classes)
|
||||
hess <- rnorm(length(as.matrix(preds)))
|
||||
return(list(grad = grad, hess = hess))
|
||||
}
|
||||
fake_merror <- function(preds, dtrain) {
|
||||
expect_equal(dim(data)[1] * nclasses, dim(as.matrix(preds))[1])
|
||||
expect_equal(dim(data)[1] * n_classes, dim(as.matrix(preds))[1])
|
||||
}
|
||||
param$objective <- fake_softprob
|
||||
param$eval_metric <- fake_merror
|
||||
bst <- xgb.train(param, dtrain, 1, num_class = nclasses)
|
||||
bst <- xgb.train(param, dtrain, 1, num_class = n_classes)
|
||||
})
|
||||
|
||||
softmax <- function(values) {
|
||||
values <- as.numeric(values)
|
||||
exps <- exp(values)
|
||||
den <- sum(exps)
|
||||
return(exps / den)
|
||||
}
|
||||
|
||||
softprob <- function(predt, dtrain) {
|
||||
y <- getinfo(dtrain, "label")
|
||||
|
||||
n_samples <- dim(predt)[1]
|
||||
n_classes <- dim(predt)[2]
|
||||
|
||||
grad <- matrix(nrow = n_samples, ncol = n_classes)
|
||||
hess <- matrix(nrow = n_samples, ncol = n_classes)
|
||||
|
||||
for (i in seq_len(n_samples)) {
|
||||
t <- y[i]
|
||||
p <- softmax(predt[i, ])
|
||||
for (c in seq_len(n_classes)) {
|
||||
g <- if (c - 1 == t) {
|
||||
p[c] - 1.0
|
||||
} else {
|
||||
p[c]
|
||||
}
|
||||
h <- max((2.0 * p[c] * (1.0 - p[c])), 1e-6)
|
||||
grad[i, c] <- g
|
||||
hess[i, c] <- h
|
||||
}
|
||||
}
|
||||
|
||||
return(list(grad = grad, hess = hess))
|
||||
}
|
||||
|
||||
|
||||
test_that("custom objective with multi-class works", {
|
||||
data <- as.matrix(iris[, -5])
|
||||
label <- as.numeric(iris$Species) - 1
|
||||
|
||||
dtrain <- xgb.DMatrix(data = data, label = label)
|
||||
|
||||
param$num_class <- 3
|
||||
param$objective <- softprob
|
||||
param$eval_metric <- "merror"
|
||||
param$base_score <- 0.5
|
||||
|
||||
custom_bst <- xgb.train(param, dtrain, 2)
|
||||
custom_predt <- predict(custom_bst, dtrain)
|
||||
|
||||
param$objective <- "multi:softmax"
|
||||
builtin_bst <- xgb.train(param, dtrain, 2)
|
||||
builtin_predt <- predict(builtin_bst, dtrain)
|
||||
|
||||
expect_equal(custom_predt, builtin_predt)
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user