[R] Enable multi-output objectives (#9839)
This commit is contained in:
@@ -52,7 +52,7 @@ extern SEXP XGBGetGlobalConfig_R(void);
|
||||
extern SEXP XGBoosterFeatureScore_R(SEXP, SEXP);
|
||||
|
||||
static const R_CallMethodDef CallEntries[] = {
|
||||
{"XGBoosterBoostOneIter_R", (DL_FUNC) &XGBoosterTrainOneIter_R, 5},
|
||||
{"XGBoosterTrainOneIter_R", (DL_FUNC) &XGBoosterTrainOneIter_R, 5},
|
||||
{"XGBoosterCreate_R", (DL_FUNC) &XGBoosterCreate_R, 1},
|
||||
{"XGBoosterCreateInEmptyObj_R", (DL_FUNC) &XGBoosterCreateInEmptyObj_R, 2},
|
||||
{"XGBoosterDumpModel_R", (DL_FUNC) &XGBoosterDumpModel_R, 4},
|
||||
|
||||
@@ -342,9 +342,11 @@ XGB_DLL SEXP XGDMatrixSaveBinary_R(SEXP handle, SEXP fname, SEXP silent) {
|
||||
XGB_DLL SEXP XGDMatrixSetInfo_R(SEXP handle, SEXP field, SEXP array) {
|
||||
R_API_BEGIN();
|
||||
SEXP field_ = PROTECT(Rf_asChar(field));
|
||||
SEXP arr_dim = Rf_getAttrib(array, R_DimSymbol);
|
||||
int res_code;
|
||||
{
|
||||
const std::string array_str = MakeArrayInterfaceFromRVector(array);
|
||||
const std::string array_str = Rf_isNull(arr_dim)?
|
||||
MakeArrayInterfaceFromRVector(array) : MakeArrayInterfaceFromRMat(array);
|
||||
res_code = XGDMatrixSetInfoFromInterface(
|
||||
R_ExternalPtrAddr(handle), CHAR(field_), array_str.c_str());
|
||||
}
|
||||
@@ -513,20 +515,14 @@ XGB_DLL SEXP XGBoosterTrainOneIter_R(SEXP handle, SEXP dtrain, SEXP iter, SEXP g
|
||||
R_API_BEGIN();
|
||||
CHECK_EQ(Rf_xlength(grad), Rf_xlength(hess)) << "gradient and hess must have same length.";
|
||||
SEXP gdim = getAttrib(grad, R_DimSymbol);
|
||||
auto n_samples = static_cast<std::size_t>(INTEGER(gdim)[0]);
|
||||
auto n_targets = static_cast<std::size_t>(INTEGER(gdim)[1]);
|
||||
|
||||
SEXP hdim = getAttrib(hess, R_DimSymbol);
|
||||
CHECK_EQ(INTEGER(hdim)[0], n_samples) << "mismatched size between gradient and hessian";
|
||||
CHECK_EQ(INTEGER(hdim)[1], n_targets) << "mismatched size between gradient and hessian";
|
||||
double const *d_grad = REAL(grad);
|
||||
double const *d_hess = REAL(hess);
|
||||
|
||||
int res_code;
|
||||
{
|
||||
auto ctx = xgboost::detail::BoosterCtx(R_ExternalPtrAddr(handle));
|
||||
auto [s_grad, s_hess] = xgboost::detail::MakeGradientInterface(
|
||||
ctx, d_grad, d_hess, xgboost::linalg::kF, n_samples, n_targets);
|
||||
const std::string s_grad = Rf_isNull(gdim)?
|
||||
MakeArrayInterfaceFromRVector(grad) : MakeArrayInterfaceFromRMat(grad);
|
||||
const std::string s_hess = Rf_isNull(hdim)?
|
||||
MakeArrayInterfaceFromRVector(hess) : MakeArrayInterfaceFromRMat(hess);
|
||||
res_code = XGBoosterTrainOneIter(R_ExternalPtrAddr(handle), R_ExternalPtrAddr(dtrain),
|
||||
asInteger(iter), s_grad.c_str(), s_hess.c_str());
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user