[R] Enable multi-output objectives (#9839)

This commit is contained in:
david-cortes
2023-12-05 20:13:14 +01:00
committed by GitHub
parent 9c56916fd7
commit 62571b79eb
5 changed files with 78 additions and 27 deletions

View File

@@ -52,7 +52,7 @@ extern SEXP XGBGetGlobalConfig_R(void);
extern SEXP XGBoosterFeatureScore_R(SEXP, SEXP);
static const R_CallMethodDef CallEntries[] = {
{"XGBoosterBoostOneIter_R", (DL_FUNC) &XGBoosterTrainOneIter_R, 5},
{"XGBoosterTrainOneIter_R", (DL_FUNC) &XGBoosterTrainOneIter_R, 5},
{"XGBoosterCreate_R", (DL_FUNC) &XGBoosterCreate_R, 1},
{"XGBoosterCreateInEmptyObj_R", (DL_FUNC) &XGBoosterCreateInEmptyObj_R, 2},
{"XGBoosterDumpModel_R", (DL_FUNC) &XGBoosterDumpModel_R, 4},

View File

@@ -342,9 +342,11 @@ XGB_DLL SEXP XGDMatrixSaveBinary_R(SEXP handle, SEXP fname, SEXP silent) {
XGB_DLL SEXP XGDMatrixSetInfo_R(SEXP handle, SEXP field, SEXP array) {
R_API_BEGIN();
SEXP field_ = PROTECT(Rf_asChar(field));
SEXP arr_dim = Rf_getAttrib(array, R_DimSymbol);
int res_code;
{
const std::string array_str = MakeArrayInterfaceFromRVector(array);
const std::string array_str = Rf_isNull(arr_dim)?
MakeArrayInterfaceFromRVector(array) : MakeArrayInterfaceFromRMat(array);
res_code = XGDMatrixSetInfoFromInterface(
R_ExternalPtrAddr(handle), CHAR(field_), array_str.c_str());
}
@@ -513,20 +515,14 @@ XGB_DLL SEXP XGBoosterTrainOneIter_R(SEXP handle, SEXP dtrain, SEXP iter, SEXP g
R_API_BEGIN();
CHECK_EQ(Rf_xlength(grad), Rf_xlength(hess)) << "gradient and hess must have same length.";
SEXP gdim = getAttrib(grad, R_DimSymbol);
auto n_samples = static_cast<std::size_t>(INTEGER(gdim)[0]);
auto n_targets = static_cast<std::size_t>(INTEGER(gdim)[1]);
SEXP hdim = getAttrib(hess, R_DimSymbol);
CHECK_EQ(INTEGER(hdim)[0], n_samples) << "mismatched size between gradient and hessian";
CHECK_EQ(INTEGER(hdim)[1], n_targets) << "mismatched size between gradient and hessian";
double const *d_grad = REAL(grad);
double const *d_hess = REAL(hess);
int res_code;
{
auto ctx = xgboost::detail::BoosterCtx(R_ExternalPtrAddr(handle));
auto [s_grad, s_hess] = xgboost::detail::MakeGradientInterface(
ctx, d_grad, d_hess, xgboost::linalg::kF, n_samples, n_targets);
const std::string s_grad = Rf_isNull(gdim)?
MakeArrayInterfaceFromRVector(grad) : MakeArrayInterfaceFromRMat(grad);
const std::string s_hess = Rf_isNull(hdim)?
MakeArrayInterfaceFromRVector(hess) : MakeArrayInterfaceFromRMat(hess);
res_code = XGBoosterTrainOneIter(R_ExternalPtrAddr(handle), R_ExternalPtrAddr(dtrain),
asInteger(iter), s_grad.c_str(), s_hess.c_str());
}