[R] Enable multi-output objectives (#9839)

This commit is contained in:
david-cortes 2023-12-05 20:13:14 +01:00 committed by GitHub
parent 9c56916fd7
commit 62571b79eb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 78 additions and 27 deletions

View File

@ -160,23 +160,24 @@ xgb.iter.update <- function(booster_handle, dtrain, iter, obj) {
) )
gpair <- obj(pred, dtrain) gpair <- obj(pred, dtrain)
n_samples <- dim(dtrain)[1] n_samples <- dim(dtrain)[1]
grad <- gpair$grad
hess <- gpair$hess
msg <- paste( if ((is.matrix(grad) && dim(grad)[1] != n_samples) ||
"Since 2.1.0, the shape of the gradient and hessian is required to be ", (is.vector(grad) && length(grad) != n_samples) ||
"(n_samples, n_targets) or (n_samples, n_classes).", (is.vector(grad) != is.vector(hess))) {
sep = "" warning(paste(
) "Since 2.1.0, the shape of the gradient and hessian is required to be ",
if (is.matrix(gpair$grad) && dim(gpair$grad)[1] != n_samples) { "(n_samples, n_targets) or (n_samples, n_classes). Will reshape assuming ",
warning(msg) "column-major order.",
} sep = ""
if (is.numeric(gpair$grad) && length(gpair$grad) != n_samples) { ))
warning(msg) grad <- matrix(grad, nrow = n_samples)
hess <- matrix(hess, nrow = n_samples)
} }
gpair$grad <- matrix(gpair$grad, nrow = n_samples)
gpair$hess <- matrix(gpair$hess, nrow = n_samples)
.Call( .Call(
XGBoosterBoostOneIter_R, booster_handle, dtrain, iter, gpair$grad, gpair$hess XGBoosterTrainOneIter_R, booster_handle, dtrain, iter, grad, hess
) )
} }
return(TRUE) return(TRUE)

View File

@ -243,6 +243,9 @@ getinfo.xgb.DMatrix <- function(object, name, ...) {
ret <- .Call(XGDMatrixGetStrFeatureInfo_R, object, name) ret <- .Call(XGDMatrixGetStrFeatureInfo_R, object, name)
} else if (name != "nrow") { } else if (name != "nrow") {
ret <- .Call(XGDMatrixGetInfo_R, object, name) ret <- .Call(XGDMatrixGetInfo_R, object, name)
if (length(ret) > nrow(object)) {
ret <- matrix(ret, nrow = nrow(object), byrow = TRUE)
}
} else { } else {
ret <- nrow(object) ret <- nrow(object)
} }
@ -286,9 +289,9 @@ setinfo <- function(object, ...) UseMethod("setinfo")
#' @export #' @export
setinfo.xgb.DMatrix <- function(object, name, info, ...) { setinfo.xgb.DMatrix <- function(object, name, info, ...) {
if (name == "label") { if (name == "label") {
if (length(info) != nrow(object)) if (NROW(info) != nrow(object))
stop("The length of labels must equal to the number of rows in the input data") stop("The length of labels must equal to the number of rows in the input data")
.Call(XGDMatrixSetInfo_R, object, name, as.numeric(info)) .Call(XGDMatrixSetInfo_R, object, name, info)
return(TRUE) return(TRUE)
} }
if (name == "label_lower_bound") { if (name == "label_lower_bound") {

View File

@ -52,7 +52,7 @@ extern SEXP XGBGetGlobalConfig_R(void);
extern SEXP XGBoosterFeatureScore_R(SEXP, SEXP); extern SEXP XGBoosterFeatureScore_R(SEXP, SEXP);
static const R_CallMethodDef CallEntries[] = { static const R_CallMethodDef CallEntries[] = {
{"XGBoosterBoostOneIter_R", (DL_FUNC) &XGBoosterTrainOneIter_R, 5}, {"XGBoosterTrainOneIter_R", (DL_FUNC) &XGBoosterTrainOneIter_R, 5},
{"XGBoosterCreate_R", (DL_FUNC) &XGBoosterCreate_R, 1}, {"XGBoosterCreate_R", (DL_FUNC) &XGBoosterCreate_R, 1},
{"XGBoosterCreateInEmptyObj_R", (DL_FUNC) &XGBoosterCreateInEmptyObj_R, 2}, {"XGBoosterCreateInEmptyObj_R", (DL_FUNC) &XGBoosterCreateInEmptyObj_R, 2},
{"XGBoosterDumpModel_R", (DL_FUNC) &XGBoosterDumpModel_R, 4}, {"XGBoosterDumpModel_R", (DL_FUNC) &XGBoosterDumpModel_R, 4},

View File

@ -342,9 +342,11 @@ XGB_DLL SEXP XGDMatrixSaveBinary_R(SEXP handle, SEXP fname, SEXP silent) {
XGB_DLL SEXP XGDMatrixSetInfo_R(SEXP handle, SEXP field, SEXP array) { XGB_DLL SEXP XGDMatrixSetInfo_R(SEXP handle, SEXP field, SEXP array) {
R_API_BEGIN(); R_API_BEGIN();
SEXP field_ = PROTECT(Rf_asChar(field)); SEXP field_ = PROTECT(Rf_asChar(field));
SEXP arr_dim = Rf_getAttrib(array, R_DimSymbol);
int res_code; int res_code;
{ {
const std::string array_str = MakeArrayInterfaceFromRVector(array); const std::string array_str = Rf_isNull(arr_dim)?
MakeArrayInterfaceFromRVector(array) : MakeArrayInterfaceFromRMat(array);
res_code = XGDMatrixSetInfoFromInterface( res_code = XGDMatrixSetInfoFromInterface(
R_ExternalPtrAddr(handle), CHAR(field_), array_str.c_str()); R_ExternalPtrAddr(handle), CHAR(field_), array_str.c_str());
} }
@ -513,20 +515,14 @@ XGB_DLL SEXP XGBoosterTrainOneIter_R(SEXP handle, SEXP dtrain, SEXP iter, SEXP g
R_API_BEGIN(); R_API_BEGIN();
CHECK_EQ(Rf_xlength(grad), Rf_xlength(hess)) << "gradient and hess must have same length."; CHECK_EQ(Rf_xlength(grad), Rf_xlength(hess)) << "gradient and hess must have same length.";
SEXP gdim = getAttrib(grad, R_DimSymbol); SEXP gdim = getAttrib(grad, R_DimSymbol);
auto n_samples = static_cast<std::size_t>(INTEGER(gdim)[0]);
auto n_targets = static_cast<std::size_t>(INTEGER(gdim)[1]);
SEXP hdim = getAttrib(hess, R_DimSymbol); SEXP hdim = getAttrib(hess, R_DimSymbol);
CHECK_EQ(INTEGER(hdim)[0], n_samples) << "mismatched size between gradient and hessian";
CHECK_EQ(INTEGER(hdim)[1], n_targets) << "mismatched size between gradient and hessian";
double const *d_grad = REAL(grad);
double const *d_hess = REAL(hess);
int res_code; int res_code;
{ {
auto ctx = xgboost::detail::BoosterCtx(R_ExternalPtrAddr(handle)); const std::string s_grad = Rf_isNull(gdim)?
auto [s_grad, s_hess] = xgboost::detail::MakeGradientInterface( MakeArrayInterfaceFromRVector(grad) : MakeArrayInterfaceFromRMat(grad);
ctx, d_grad, d_hess, xgboost::linalg::kF, n_samples, n_targets); const std::string s_hess = Rf_isNull(hdim)?
MakeArrayInterfaceFromRVector(hess) : MakeArrayInterfaceFromRMat(hess);
res_code = XGBoosterTrainOneIter(R_ExternalPtrAddr(handle), R_ExternalPtrAddr(dtrain), res_code = XGBoosterTrainOneIter(R_ExternalPtrAddr(handle), R_ExternalPtrAddr(dtrain),
asInteger(iter), s_grad.c_str(), s_hess.c_str()); asInteger(iter), s_grad.c_str(), s_hess.c_str());
} }

View File

@ -565,3 +565,54 @@ test_that("'predict' accepts CSR data", {
expect_equal(p_csc, p_csr) expect_equal(p_csc, p_csr)
expect_equal(p_csc, p_spv) expect_equal(p_csc, p_spv)
}) })
test_that("Can use multi-output labels with built-in objectives", {
  # Train on a two-column label (mpg and its negation) with a built-in
  # objective, then check that the two predicted columns mirror each other
  # and track the original target.
  data("mtcars")
  features <- as.matrix(mtcars[, -1])
  y <- mtcars$mpg
  two_col_label <- cbind(y, -y)
  dtrain <- xgb.DMatrix(features, label = two_col_label, nthread = n_threads)

  train_params <- list(
    tree_method = "hist",
    multi_strategy = "multi_output_tree",
    objective = "reg:squarederror",
    nthread = n_threads
  )
  booster <- xgb.train(params = train_params, data = dtrain, nrounds = 5)

  preds <- predict(booster, features, reshape = TRUE)
  first_output <- preds[, 1]
  second_output <- preds[, 2]

  # Mirrored labels must give mirrored predictions.
  expect_equal(first_output, -second_output)
  # Each output should correlate strongly with the target, with the sign of
  # the label column it was fit on.
  expect_true(cor(y, first_output) > 0.9)
  expect_true(cor(y, second_output) < -0.9)
})
test_that("Can use multi-output labels with custom objectives", {
  # Same mirrored two-column label as the built-in objective test, but the
  # objective is supplied as an R function returning matrix-shaped grad/hess.
  data("mtcars")
  features <- as.matrix(mtcars[, -1])
  y <- mtcars$mpg
  two_col_label <- cbind(y, -y)
  dtrain <- xgb.DMatrix(features, label = two_col_label, nthread = n_threads)

  # Custom objective: gradient is (pred - label), hessian is all ones with
  # the same shape as the gradient.
  custom_objective <- function(pred, dtrain) {
    labels <- getinfo(dtrain, "label")
    grad <- pred - labels
    hess <- matrix(1, nrow = nrow(grad), ncol = ncol(grad))
    list(grad = grad, hess = hess)
  }

  train_params <- list(
    tree_method = "hist",
    multi_strategy = "multi_output_tree",
    base_score = 0,
    objective = custom_objective,
    nthread = n_threads
  )
  booster <- xgb.train(params = train_params, data = dtrain, nrounds = 5)

  preds <- predict(booster, features, reshape = TRUE)
  first_output <- preds[, 1]
  second_output <- preds[, 2]

  # Mirrored labels must give mirrored predictions.
  expect_equal(first_output, -second_output)
  # Each output should correlate strongly with the target, with the sign of
  # the label column it was fit on.
  expect_true(cor(y, first_output) > 0.9)
  expect_true(cor(y, second_output) < -0.9)
})