[R] Enable multi-output objectives (#9839)
This commit is contained in:
parent
9c56916fd7
commit
62571b79eb
@ -160,23 +160,24 @@ xgb.iter.update <- function(booster_handle, dtrain, iter, obj) {
|
|||||||
)
|
)
|
||||||
gpair <- obj(pred, dtrain)
|
gpair <- obj(pred, dtrain)
|
||||||
n_samples <- dim(dtrain)[1]
|
n_samples <- dim(dtrain)[1]
|
||||||
|
grad <- gpair$grad
|
||||||
|
hess <- gpair$hess
|
||||||
|
|
||||||
msg <- paste(
|
if ((is.matrix(grad) && dim(grad)[1] != n_samples) ||
|
||||||
"Since 2.1.0, the shape of the gradient and hessian is required to be ",
|
(is.vector(grad) && length(grad) != n_samples) ||
|
||||||
"(n_samples, n_targets) or (n_samples, n_classes).",
|
(is.vector(grad) != is.vector(hess))) {
|
||||||
sep = ""
|
warning(paste(
|
||||||
)
|
"Since 2.1.0, the shape of the gradient and hessian is required to be ",
|
||||||
if (is.matrix(gpair$grad) && dim(gpair$grad)[1] != n_samples) {
|
"(n_samples, n_targets) or (n_samples, n_classes). Will reshape assuming ",
|
||||||
warning(msg)
|
"column-major order.",
|
||||||
}
|
sep = ""
|
||||||
if (is.numeric(gpair$grad) && length(gpair$grad) != n_samples) {
|
))
|
||||||
warning(msg)
|
grad <- matrix(grad, nrow = n_samples)
|
||||||
|
hess <- matrix(hess, nrow = n_samples)
|
||||||
}
|
}
|
||||||
|
|
||||||
gpair$grad <- matrix(gpair$grad, nrow = n_samples)
|
|
||||||
gpair$hess <- matrix(gpair$hess, nrow = n_samples)
|
|
||||||
.Call(
|
.Call(
|
||||||
XGBoosterBoostOneIter_R, booster_handle, dtrain, iter, gpair$grad, gpair$hess
|
XGBoosterTrainOneIter_R, booster_handle, dtrain, iter, grad, hess
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
return(TRUE)
|
return(TRUE)
|
||||||
|
|||||||
@ -243,6 +243,9 @@ getinfo.xgb.DMatrix <- function(object, name, ...) {
|
|||||||
ret <- .Call(XGDMatrixGetStrFeatureInfo_R, object, name)
|
ret <- .Call(XGDMatrixGetStrFeatureInfo_R, object, name)
|
||||||
} else if (name != "nrow") {
|
} else if (name != "nrow") {
|
||||||
ret <- .Call(XGDMatrixGetInfo_R, object, name)
|
ret <- .Call(XGDMatrixGetInfo_R, object, name)
|
||||||
|
if (length(ret) > nrow(object)) {
|
||||||
|
ret <- matrix(ret, nrow = nrow(object), byrow = TRUE)
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
ret <- nrow(object)
|
ret <- nrow(object)
|
||||||
}
|
}
|
||||||
@ -286,9 +289,9 @@ setinfo <- function(object, ...) UseMethod("setinfo")
|
|||||||
#' @export
|
#' @export
|
||||||
setinfo.xgb.DMatrix <- function(object, name, info, ...) {
|
setinfo.xgb.DMatrix <- function(object, name, info, ...) {
|
||||||
if (name == "label") {
|
if (name == "label") {
|
||||||
if (length(info) != nrow(object))
|
if (NROW(info) != nrow(object))
|
||||||
stop("The length of labels must equal to the number of rows in the input data")
|
stop("The length of labels must equal to the number of rows in the input data")
|
||||||
.Call(XGDMatrixSetInfo_R, object, name, as.numeric(info))
|
.Call(XGDMatrixSetInfo_R, object, name, info)
|
||||||
return(TRUE)
|
return(TRUE)
|
||||||
}
|
}
|
||||||
if (name == "label_lower_bound") {
|
if (name == "label_lower_bound") {
|
||||||
|
|||||||
@ -52,7 +52,7 @@ extern SEXP XGBGetGlobalConfig_R(void);
|
|||||||
extern SEXP XGBoosterFeatureScore_R(SEXP, SEXP);
|
extern SEXP XGBoosterFeatureScore_R(SEXP, SEXP);
|
||||||
|
|
||||||
static const R_CallMethodDef CallEntries[] = {
|
static const R_CallMethodDef CallEntries[] = {
|
||||||
{"XGBoosterBoostOneIter_R", (DL_FUNC) &XGBoosterTrainOneIter_R, 5},
|
{"XGBoosterTrainOneIter_R", (DL_FUNC) &XGBoosterTrainOneIter_R, 5},
|
||||||
{"XGBoosterCreate_R", (DL_FUNC) &XGBoosterCreate_R, 1},
|
{"XGBoosterCreate_R", (DL_FUNC) &XGBoosterCreate_R, 1},
|
||||||
{"XGBoosterCreateInEmptyObj_R", (DL_FUNC) &XGBoosterCreateInEmptyObj_R, 2},
|
{"XGBoosterCreateInEmptyObj_R", (DL_FUNC) &XGBoosterCreateInEmptyObj_R, 2},
|
||||||
{"XGBoosterDumpModel_R", (DL_FUNC) &XGBoosterDumpModel_R, 4},
|
{"XGBoosterDumpModel_R", (DL_FUNC) &XGBoosterDumpModel_R, 4},
|
||||||
|
|||||||
@ -342,9 +342,11 @@ XGB_DLL SEXP XGDMatrixSaveBinary_R(SEXP handle, SEXP fname, SEXP silent) {
|
|||||||
XGB_DLL SEXP XGDMatrixSetInfo_R(SEXP handle, SEXP field, SEXP array) {
|
XGB_DLL SEXP XGDMatrixSetInfo_R(SEXP handle, SEXP field, SEXP array) {
|
||||||
R_API_BEGIN();
|
R_API_BEGIN();
|
||||||
SEXP field_ = PROTECT(Rf_asChar(field));
|
SEXP field_ = PROTECT(Rf_asChar(field));
|
||||||
|
SEXP arr_dim = Rf_getAttrib(array, R_DimSymbol);
|
||||||
int res_code;
|
int res_code;
|
||||||
{
|
{
|
||||||
const std::string array_str = MakeArrayInterfaceFromRVector(array);
|
const std::string array_str = Rf_isNull(arr_dim)?
|
||||||
|
MakeArrayInterfaceFromRVector(array) : MakeArrayInterfaceFromRMat(array);
|
||||||
res_code = XGDMatrixSetInfoFromInterface(
|
res_code = XGDMatrixSetInfoFromInterface(
|
||||||
R_ExternalPtrAddr(handle), CHAR(field_), array_str.c_str());
|
R_ExternalPtrAddr(handle), CHAR(field_), array_str.c_str());
|
||||||
}
|
}
|
||||||
@ -513,20 +515,14 @@ XGB_DLL SEXP XGBoosterTrainOneIter_R(SEXP handle, SEXP dtrain, SEXP iter, SEXP g
|
|||||||
R_API_BEGIN();
|
R_API_BEGIN();
|
||||||
CHECK_EQ(Rf_xlength(grad), Rf_xlength(hess)) << "gradient and hess must have same length.";
|
CHECK_EQ(Rf_xlength(grad), Rf_xlength(hess)) << "gradient and hess must have same length.";
|
||||||
SEXP gdim = getAttrib(grad, R_DimSymbol);
|
SEXP gdim = getAttrib(grad, R_DimSymbol);
|
||||||
auto n_samples = static_cast<std::size_t>(INTEGER(gdim)[0]);
|
|
||||||
auto n_targets = static_cast<std::size_t>(INTEGER(gdim)[1]);
|
|
||||||
|
|
||||||
SEXP hdim = getAttrib(hess, R_DimSymbol);
|
SEXP hdim = getAttrib(hess, R_DimSymbol);
|
||||||
CHECK_EQ(INTEGER(hdim)[0], n_samples) << "mismatched size between gradient and hessian";
|
|
||||||
CHECK_EQ(INTEGER(hdim)[1], n_targets) << "mismatched size between gradient and hessian";
|
|
||||||
double const *d_grad = REAL(grad);
|
|
||||||
double const *d_hess = REAL(hess);
|
|
||||||
|
|
||||||
int res_code;
|
int res_code;
|
||||||
{
|
{
|
||||||
auto ctx = xgboost::detail::BoosterCtx(R_ExternalPtrAddr(handle));
|
const std::string s_grad = Rf_isNull(gdim)?
|
||||||
auto [s_grad, s_hess] = xgboost::detail::MakeGradientInterface(
|
MakeArrayInterfaceFromRVector(grad) : MakeArrayInterfaceFromRMat(grad);
|
||||||
ctx, d_grad, d_hess, xgboost::linalg::kF, n_samples, n_targets);
|
const std::string s_hess = Rf_isNull(hdim)?
|
||||||
|
MakeArrayInterfaceFromRVector(hess) : MakeArrayInterfaceFromRMat(hess);
|
||||||
res_code = XGBoosterTrainOneIter(R_ExternalPtrAddr(handle), R_ExternalPtrAddr(dtrain),
|
res_code = XGBoosterTrainOneIter(R_ExternalPtrAddr(handle), R_ExternalPtrAddr(dtrain),
|
||||||
asInteger(iter), s_grad.c_str(), s_hess.c_str());
|
asInteger(iter), s_grad.c_str(), s_hess.c_str());
|
||||||
}
|
}
|
||||||
|
|||||||
@ -565,3 +565,54 @@ test_that("'predict' accepts CSR data", {
|
|||||||
expect_equal(p_csc, p_csr)
|
expect_equal(p_csc, p_csr)
|
||||||
expect_equal(p_csc, p_spv)
|
expect_equal(p_csc, p_spv)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
test_that("Can use multi-output labels with built-in objectives", {
|
||||||
|
data("mtcars")
|
||||||
|
y <- mtcars$mpg
|
||||||
|
x <- as.matrix(mtcars[, -1])
|
||||||
|
y_mirrored <- cbind(y, -y)
|
||||||
|
dm <- xgb.DMatrix(x, label = y_mirrored, nthread = n_threads)
|
||||||
|
model <- xgb.train(
|
||||||
|
params = list(
|
||||||
|
tree_method = "hist",
|
||||||
|
multi_strategy = "multi_output_tree",
|
||||||
|
objective = "reg:squarederror",
|
||||||
|
nthread = n_threads
|
||||||
|
),
|
||||||
|
data = dm,
|
||||||
|
nrounds = 5
|
||||||
|
)
|
||||||
|
pred <- predict(model, x, reshape = TRUE)
|
||||||
|
expect_equal(pred[, 1], -pred[, 2])
|
||||||
|
expect_true(cor(y, pred[, 1]) > 0.9)
|
||||||
|
expect_true(cor(y, pred[, 2]) < -0.9)
|
||||||
|
})
|
||||||
|
|
||||||
|
test_that("Can use multi-output labels with custom objectives", {
|
||||||
|
data("mtcars")
|
||||||
|
y <- mtcars$mpg
|
||||||
|
x <- as.matrix(mtcars[, -1])
|
||||||
|
y_mirrored <- cbind(y, -y)
|
||||||
|
dm <- xgb.DMatrix(x, label = y_mirrored, nthread = n_threads)
|
||||||
|
model <- xgb.train(
|
||||||
|
params = list(
|
||||||
|
tree_method = "hist",
|
||||||
|
multi_strategy = "multi_output_tree",
|
||||||
|
base_score = 0,
|
||||||
|
objective = function(pred, dtrain) {
|
||||||
|
y <- getinfo(dtrain, "label")
|
||||||
|
grad <- pred - y
|
||||||
|
hess <- rep(1, nrow(grad) * ncol(grad))
|
||||||
|
hess <- matrix(hess, nrow = nrow(grad))
|
||||||
|
return(list(grad = grad, hess = hess))
|
||||||
|
},
|
||||||
|
nthread = n_threads
|
||||||
|
),
|
||||||
|
data = dm,
|
||||||
|
nrounds = 5
|
||||||
|
)
|
||||||
|
pred <- predict(model, x, reshape = TRUE)
|
||||||
|
expect_equal(pred[, 1], -pred[, 2])
|
||||||
|
expect_true(cor(y, pred[, 1]) > 0.9)
|
||||||
|
expect_true(cor(y, pred[, 2]) < -0.9)
|
||||||
|
})
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user