[R] Enable multi-output objectives (#9839)
This commit is contained in:
@@ -160,23 +160,24 @@ xgb.iter.update <- function(booster_handle, dtrain, iter, obj) {
|
||||
)
|
||||
gpair <- obj(pred, dtrain)
|
||||
n_samples <- dim(dtrain)[1]
|
||||
grad <- gpair$grad
|
||||
hess <- gpair$hess
|
||||
|
||||
msg <- paste(
|
||||
"Since 2.1.0, the shape of the gradient and hessian is required to be ",
|
||||
"(n_samples, n_targets) or (n_samples, n_classes).",
|
||||
sep = ""
|
||||
)
|
||||
if (is.matrix(gpair$grad) && dim(gpair$grad)[1] != n_samples) {
|
||||
warning(msg)
|
||||
}
|
||||
if (is.numeric(gpair$grad) && length(gpair$grad) != n_samples) {
|
||||
warning(msg)
|
||||
if ((is.matrix(grad) && dim(grad)[1] != n_samples) ||
|
||||
(is.vector(grad) && length(grad) != n_samples) ||
|
||||
(is.vector(grad) != is.vector(hess))) {
|
||||
warning(paste(
|
||||
"Since 2.1.0, the shape of the gradient and hessian is required to be ",
|
||||
"(n_samples, n_targets) or (n_samples, n_classes). Will reshape assuming ",
|
||||
"column-major order.",
|
||||
sep = ""
|
||||
))
|
||||
grad <- matrix(grad, nrow = n_samples)
|
||||
hess <- matrix(hess, nrow = n_samples)
|
||||
}
|
||||
|
||||
gpair$grad <- matrix(gpair$grad, nrow = n_samples)
|
||||
gpair$hess <- matrix(gpair$hess, nrow = n_samples)
|
||||
.Call(
|
||||
XGBoosterBoostOneIter_R, booster_handle, dtrain, iter, gpair$grad, gpair$hess
|
||||
XGBoosterTrainOneIter_R, booster_handle, dtrain, iter, grad, hess
|
||||
)
|
||||
}
|
||||
return(TRUE)
|
||||
|
||||
@@ -243,6 +243,9 @@ getinfo.xgb.DMatrix <- function(object, name, ...) {
|
||||
ret <- .Call(XGDMatrixGetStrFeatureInfo_R, object, name)
|
||||
} else if (name != "nrow") {
|
||||
ret <- .Call(XGDMatrixGetInfo_R, object, name)
|
||||
if (length(ret) > nrow(object)) {
|
||||
ret <- matrix(ret, nrow = nrow(object), byrow = TRUE)
|
||||
}
|
||||
} else {
|
||||
ret <- nrow(object)
|
||||
}
|
||||
@@ -286,9 +289,9 @@ setinfo <- function(object, ...) UseMethod("setinfo")
|
||||
#' @export
|
||||
setinfo.xgb.DMatrix <- function(object, name, info, ...) {
|
||||
if (name == "label") {
|
||||
if (length(info) != nrow(object))
|
||||
if (NROW(info) != nrow(object))
|
||||
stop("The length of labels must equal to the number of rows in the input data")
|
||||
.Call(XGDMatrixSetInfo_R, object, name, as.numeric(info))
|
||||
.Call(XGDMatrixSetInfo_R, object, name, info)
|
||||
return(TRUE)
|
||||
}
|
||||
if (name == "label_lower_bound") {
|
||||
|
||||
Reference in New Issue
Block a user