[R] Support multi-class custom objective. (#9526)
This commit is contained in:
@@ -151,14 +151,30 @@ xgb.iter.update <- function(booster_handle, dtrain, iter, obj) {
|
||||
if (is.null(obj)) {
|
||||
.Call(XGBoosterUpdateOneIter_R, booster_handle, as.integer(iter), dtrain)
|
||||
} else {
|
||||
pred <- predict(booster_handle, dtrain, outputmargin = TRUE, training = TRUE,
|
||||
ntreelimit = 0)
|
||||
pred <- predict(
|
||||
booster_handle,
|
||||
dtrain,
|
||||
outputmargin = TRUE,
|
||||
training = TRUE,
|
||||
reshape = TRUE
|
||||
)
|
||||
gpair <- obj(pred, dtrain)
|
||||
n_samples <- dim(dtrain)[1]
|
||||
# We still require row-major in R as I'm not quite sure sure how to get the stride of
|
||||
# the matrix in C.
|
||||
gpair$grad <- matrix(gpair$grad, nrow = n_samples, byrow = TRUE)
|
||||
gpair$hess <- matrix(gpair$hess, nrow = n_samples, byrow = TRUE)
|
||||
|
||||
msg <- paste(
|
||||
"Since 2.1.0, the shape of the gradient and hessian is required to be ",
|
||||
"(n_samples, n_targets) or (n_samples, n_classes).",
|
||||
sep = ""
|
||||
)
|
||||
if (is.matrix(gpair$grad) && dim(gpair$grad)[1] != n_samples) {
|
||||
warning(msg)
|
||||
}
|
||||
if (is.numeric(gpair$grad) && length(gpair$grad) != n_samples) {
|
||||
warning(msg)
|
||||
}
|
||||
|
||||
gpair$grad <- matrix(gpair$grad, nrow = n_samples)
|
||||
gpair$hess <- matrix(gpair$hess, nrow = n_samples)
|
||||
.Call(
|
||||
XGBoosterBoostOneIter_R, booster_handle, dtrain, iter, gpair$grad, gpair$hess
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user