[R] Support multi-class custom objective. (#9526)

This commit is contained in:
Jiaming Yuan
2023-08-29 08:27:13 +08:00
committed by GitHub
parent 90ef250ea1
commit be6a552956
6 changed files with 106 additions and 26 deletions

View File

@@ -403,7 +403,7 @@ XGB_DLL SEXP XGBoosterUpdateOneIter_R(SEXP handle, SEXP iter, SEXP dtrain) {
XGB_DLL SEXP XGBoosterTrainOneIter_R(SEXP handle, SEXP dtrain, SEXP iter, SEXP grad, SEXP hess) {
R_API_BEGIN();
CHECK_EQ(length(grad), length(hess)) << "gradient and hess must have same length";
CHECK_EQ(length(grad), length(hess)) << "gradient and hess must have same length.";
SEXP gdim = getAttrib(grad, R_DimSymbol);
auto n_samples = static_cast<std::size_t>(INTEGER(gdim)[0]);
auto n_targets = static_cast<std::size_t>(INTEGER(gdim)[1]);
@@ -415,8 +415,8 @@ XGB_DLL SEXP XGBoosterTrainOneIter_R(SEXP handle, SEXP dtrain, SEXP iter, SEXP g
double const *d_hess = REAL(hess);
auto ctx = xgboost::detail::BoosterCtx(R_ExternalPtrAddr(handle));
auto [s_grad, s_hess] =
xgboost::detail::MakeGradientInterface(ctx, d_grad, d_hess, n_samples, n_targets);
auto [s_grad, s_hess] = xgboost::detail::MakeGradientInterface(
ctx, d_grad, d_hess, xgboost::linalg::kF, n_samples, n_targets);
CHECK_CALL(XGBoosterTrainOneIter(R_ExternalPtrAddr(handle), R_ExternalPtrAddr(dtrain),
asInteger(iter), s_grad.c_str(), s_hess.c_str()));
@@ -435,7 +435,7 @@ XGB_DLL SEXP XGBoosterEvalOneIter_R(SEXP handle, SEXP iter, SEXP dmats, SEXP evn
std::vector<const char*> vec_sptr;
for (int i = 0; i < len; ++i) {
vec_dmats.push_back(R_ExternalPtrAddr(VECTOR_ELT(dmats, i)));
vec_names.push_back(std::string(CHAR(asChar(VECTOR_ELT(evnames, i)))));
vec_names.emplace_back(CHAR(asChar(VECTOR_ELT(evnames, i))));
}
for (int i = 0; i < len; ++i) {
vec_sptr.push_back(vec_names[i].c_str());