[R] Fix memory safety issues (#9823)

This commit is contained in:
david-cortes 2023-12-02 06:43:50 +01:00 committed by GitHub
parent e78b46046e
commit 7196c9d95e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -23,6 +23,31 @@
#include "./xgboost_R.h" // Must follow other includes. #include "./xgboost_R.h" // Must follow other includes.
namespace { namespace {
struct ErrorWithUnwind : public std::exception {};
void ThrowExceptionFromRError(void *unused, Rboolean jump) {
if (jump) {
throw ErrorWithUnwind();
}
}
struct PtrToConstChar {
const char *ptr;
};
SEXP WrappedMkChar(void *void_ptr) {
return Rf_mkChar(static_cast<PtrToConstChar*>(void_ptr)->ptr);
}
SEXP SafeMkChar(const char *c_str, SEXP continuation_token) {
PtrToConstChar ptr_struct{c_str};
return R_UnwindProtect(
WrappedMkChar, static_cast<void*>(&ptr_struct),
ThrowExceptionFromRError, nullptr,
continuation_token);
}
[[nodiscard]] std::string MakeArrayInterfaceFromRMat(SEXP R_mat) { [[nodiscard]] std::string MakeArrayInterfaceFromRMat(SEXP R_mat) {
SEXP mat_dims = Rf_getAttrib(R_mat, R_DimSymbol); SEXP mat_dims = Rf_getAttrib(R_mat, R_DimSymbol);
const int *ptr_mat_dims = INTEGER(mat_dims); const int *ptr_mat_dims = INTEGER(mat_dims);
@ -208,8 +233,8 @@ void CreateFromSparse(SEXP indptr, SEXP indices, SEXP data, std::string *indptr_
const int *p_indices = INTEGER(indices); const int *p_indices = INTEGER(indices);
const double *p_data = REAL(data); const double *p_data = REAL(data);
auto nindptr = static_cast<std::size_t>(length(indptr)); auto nindptr = static_cast<std::size_t>(Rf_xlength(indptr));
auto ndata = static_cast<std::size_t>(length(data)); auto ndata = static_cast<std::size_t>(Rf_xlength(data));
CHECK_EQ(ndata, p_indptr[nindptr - 1]); CHECK_EQ(ndata, p_indptr[nindptr - 1]);
xgboost::detail::MakeSparseFromPtr(p_indptr, p_indices, p_data, nindptr, indptr_str, indices_str, xgboost::detail::MakeSparseFromPtr(p_indptr, p_indices, p_data, nindptr, indptr_str, indices_str,
data_str); data_str);
@ -221,24 +246,27 @@ XGB_DLL SEXP XGDMatrixCreateFromCSC_R(SEXP indptr, SEXP indices, SEXP data, SEXP
SEXP ret = PROTECT(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue)); SEXP ret = PROTECT(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
R_API_BEGIN(); R_API_BEGIN();
std::int32_t threads = asInteger(n_threads); std::int32_t threads = asInteger(n_threads);
using xgboost::Integer;
using xgboost::Json;
using xgboost::Object;
std::string sindptr, sindices, sdata;
CreateFromSparse(indptr, indices, data, &sindptr, &sindices, &sdata);
auto nrow = static_cast<std::size_t>(INTEGER(num_row)[0]);
DMatrixHandle handle; DMatrixHandle handle;
Json jconfig{Object{}};
// Construct configuration int res_code;
jconfig["nthread"] = Integer{threads}; {
jconfig["missing"] = xgboost::Number{asReal(missing)}; using xgboost::Integer;
std::string config; using xgboost::Json;
Json::Dump(jconfig, &config); using xgboost::Object;
CHECK_CALL(XGDMatrixCreateFromCSC(sindptr.c_str(), sindices.c_str(), sdata.c_str(), nrow, std::string sindptr, sindices, sdata;
config.c_str(), &handle)); CreateFromSparse(indptr, indices, data, &sindptr, &sindices, &sdata);
auto nrow = static_cast<std::size_t>(INTEGER(num_row)[0]);
Json jconfig{Object{}};
// Construct configuration
jconfig["nthread"] = Integer{threads};
jconfig["missing"] = xgboost::Number{asReal(missing)};
std::string config;
Json::Dump(jconfig, &config);
res_code = XGDMatrixCreateFromCSC(sindptr.c_str(), sindices.c_str(), sdata.c_str(), nrow,
config.c_str(), &handle);
}
CHECK_CALL(res_code);
R_SetExternalPtrAddr(ret, handle); R_SetExternalPtrAddr(ret, handle);
R_RegisterCFinalizerEx(ret, _DMatrixFinalizer, TRUE); R_RegisterCFinalizerEx(ret, _DMatrixFinalizer, TRUE);
@ -252,24 +280,27 @@ XGB_DLL SEXP XGDMatrixCreateFromCSR_R(SEXP indptr, SEXP indices, SEXP data, SEXP
SEXP ret = PROTECT(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue)); SEXP ret = PROTECT(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
R_API_BEGIN(); R_API_BEGIN();
std::int32_t threads = asInteger(n_threads); std::int32_t threads = asInteger(n_threads);
using xgboost::Integer;
using xgboost::Json;
using xgboost::Object;
std::string sindptr, sindices, sdata;
CreateFromSparse(indptr, indices, data, &sindptr, &sindices, &sdata);
auto ncol = static_cast<std::size_t>(INTEGER(num_col)[0]);
DMatrixHandle handle; DMatrixHandle handle;
Json jconfig{Object{}};
// Construct configuration int res_code;
jconfig["nthread"] = Integer{threads}; {
jconfig["missing"] = xgboost::Number{asReal(missing)}; using xgboost::Integer;
std::string config; using xgboost::Json;
Json::Dump(jconfig, &config); using xgboost::Object;
CHECK_CALL(XGDMatrixCreateFromCSR(sindptr.c_str(), sindices.c_str(), sdata.c_str(), ncol,
config.c_str(), &handle)); std::string sindptr, sindices, sdata;
CreateFromSparse(indptr, indices, data, &sindptr, &sindices, &sdata);
auto ncol = static_cast<std::size_t>(INTEGER(num_col)[0]);
Json jconfig{Object{}};
// Construct configuration
jconfig["nthread"] = Integer{threads};
jconfig["missing"] = xgboost::Number{asReal(missing)};
std::string config;
Json::Dump(jconfig, &config);
res_code = XGDMatrixCreateFromCSR(sindptr.c_str(), sindices.c_str(), sdata.c_str(), ncol,
config.c_str(), &handle);
}
R_SetExternalPtrAddr(ret, handle); R_SetExternalPtrAddr(ret, handle);
R_RegisterCFinalizerEx(ret, _DMatrixFinalizer, TRUE); R_RegisterCFinalizerEx(ret, _DMatrixFinalizer, TRUE);
R_API_END(); R_API_END();
@ -280,16 +311,22 @@ XGB_DLL SEXP XGDMatrixCreateFromCSR_R(SEXP indptr, SEXP indices, SEXP data, SEXP
XGB_DLL SEXP XGDMatrixSliceDMatrix_R(SEXP handle, SEXP idxset) { XGB_DLL SEXP XGDMatrixSliceDMatrix_R(SEXP handle, SEXP idxset) {
SEXP ret = PROTECT(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue)); SEXP ret = PROTECT(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
R_API_BEGIN(); R_API_BEGIN();
int len = length(idxset); R_xlen_t len = Rf_xlength(idxset);
std::vector<int> idxvec(len); const int *idxset_ = INTEGER(idxset);
for (int i = 0; i < len; ++i) {
idxvec[i] = INTEGER(idxset)[i] - 1;
}
DMatrixHandle res; DMatrixHandle res;
CHECK_CALL(XGDMatrixSliceDMatrixEx(R_ExternalPtrAddr(handle),
BeginPtr(idxvec), len, int res_code;
&res, {
0)); std::vector<int> idxvec(len);
for (R_xlen_t i = 0; i < len; ++i) {
idxvec[i] = idxset_[i] - 1;
}
res_code = XGDMatrixSliceDMatrixEx(R_ExternalPtrAddr(handle),
BeginPtr(idxvec), len,
&res,
0);
}
CHECK_CALL(res_code);
R_SetExternalPtrAddr(ret, res); R_SetExternalPtrAddr(ret, res);
R_RegisterCFinalizerEx(ret, _DMatrixFinalizer, TRUE); R_RegisterCFinalizerEx(ret, _DMatrixFinalizer, TRUE);
R_API_END(); R_API_END();
@ -325,18 +362,29 @@ XGB_DLL SEXP XGDMatrixSetStrFeatureInfo_R(SEXP handle, SEXP field, SEXP array) {
R_API_BEGIN(); R_API_BEGIN();
size_t len{0}; size_t len{0};
if (!isNull(array)) { if (!isNull(array)) {
len = length(array); len = Rf_xlength(array);
} }
const char *name = CHAR(asChar(field)); SEXP str_info_holder = PROTECT(Rf_allocVector(VECSXP, len));
std::vector<std::string> str_info;
for (size_t i = 0; i < len; ++i) { for (size_t i = 0; i < len; ++i) {
str_info.emplace_back(CHAR(asChar(VECTOR_ELT(array, i)))); SET_VECTOR_ELT(str_info_holder, i, Rf_asChar(VECTOR_ELT(array, i)));
} }
std::vector<char const*> vec(len);
std::transform(str_info.cbegin(), str_info.cend(), vec.begin(), SEXP field_ = PROTECT(Rf_asChar(field));
[](std::string const &str) { return str.c_str(); }); const char *name = CHAR(field_);
CHECK_CALL(XGDMatrixSetStrFeatureInfo(R_ExternalPtrAddr(handle), name, vec.data(), len)); int res_code;
{
std::vector<std::string> str_info;
for (size_t i = 0; i < len; ++i) {
str_info.emplace_back(CHAR(VECTOR_ELT(str_info_holder, i)));
}
std::vector<char const*> vec(len);
std::transform(str_info.cbegin(), str_info.cend(), vec.begin(),
[](std::string const &str) { return str.c_str(); });
res_code = XGDMatrixSetStrFeatureInfo(R_ExternalPtrAddr(handle), name, vec.data(), len);
}
CHECK_CALL(res_code);
UNPROTECT(2);
R_API_END(); R_API_END();
return R_NilValue; return R_NilValue;
} }
@ -369,8 +417,9 @@ XGB_DLL SEXP XGDMatrixGetInfo_R(SEXP handle, SEXP field) {
const float *res; const float *res;
CHECK_CALL(XGDMatrixGetFloatInfo(R_ExternalPtrAddr(handle), CHAR(asChar(field)), &olen, &res)); CHECK_CALL(XGDMatrixGetFloatInfo(R_ExternalPtrAddr(handle), CHAR(asChar(field)), &olen, &res));
ret = PROTECT(allocVector(REALSXP, olen)); ret = PROTECT(allocVector(REALSXP, olen));
double *ret_ = REAL(ret);
for (size_t i = 0; i < olen; ++i) { for (size_t i = 0; i < olen; ++i) {
REAL(ret)[i] = res[i]; ret_[i] = res[i];
} }
R_API_END(); R_API_END();
UNPROTECT(1); UNPROTECT(1);
@ -403,13 +452,18 @@ void _BoosterFinalizer(SEXP ext) {
XGB_DLL SEXP XGBoosterCreate_R(SEXP dmats) { XGB_DLL SEXP XGBoosterCreate_R(SEXP dmats) {
SEXP ret = PROTECT(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue)); SEXP ret = PROTECT(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
R_API_BEGIN(); R_API_BEGIN();
int len = length(dmats); R_xlen_t len = Rf_xlength(dmats);
std::vector<void*> dvec;
for (int i = 0; i < len; ++i) {
dvec.push_back(R_ExternalPtrAddr(VECTOR_ELT(dmats, i)));
}
BoosterHandle handle; BoosterHandle handle;
CHECK_CALL(XGBoosterCreate(BeginPtr(dvec), dvec.size(), &handle));
int res_code;
{
std::vector<void*> dvec;
for (R_xlen_t i = 0; i < len; ++i) {
dvec.push_back(R_ExternalPtrAddr(VECTOR_ELT(dmats, i)));
}
res_code = XGBoosterCreate(BeginPtr(dvec), dvec.size(), &handle);
}
CHECK_CALL(res_code);
R_SetExternalPtrAddr(ret, handle); R_SetExternalPtrAddr(ret, handle);
R_RegisterCFinalizerEx(ret, _BoosterFinalizer, TRUE); R_RegisterCFinalizerEx(ret, _BoosterFinalizer, TRUE);
R_API_END(); R_API_END();
@ -419,13 +473,18 @@ XGB_DLL SEXP XGBoosterCreate_R(SEXP dmats) {
XGB_DLL SEXP XGBoosterCreateInEmptyObj_R(SEXP dmats, SEXP R_handle) { XGB_DLL SEXP XGBoosterCreateInEmptyObj_R(SEXP dmats, SEXP R_handle) {
R_API_BEGIN(); R_API_BEGIN();
int len = length(dmats); R_xlen_t len = Rf_xlength(dmats);
std::vector<void*> dvec;
for (int i = 0; i < len; ++i) {
dvec.push_back(R_ExternalPtrAddr(VECTOR_ELT(dmats, i)));
}
BoosterHandle handle; BoosterHandle handle;
CHECK_CALL(XGBoosterCreate(BeginPtr(dvec), dvec.size(), &handle));
int res_code;
{
std::vector<void*> dvec;
for (R_xlen_t i = 0; i < len; ++i) {
dvec.push_back(R_ExternalPtrAddr(VECTOR_ELT(dmats, i)));
}
res_code = XGBoosterCreate(BeginPtr(dvec), dvec.size(), &handle);
}
CHECK_CALL(res_code);
R_SetExternalPtrAddr(R_handle, handle); R_SetExternalPtrAddr(R_handle, handle);
R_RegisterCFinalizerEx(R_handle, _BoosterFinalizer, TRUE); R_RegisterCFinalizerEx(R_handle, _BoosterFinalizer, TRUE);
R_API_END(); R_API_END();
@ -434,9 +493,12 @@ XGB_DLL SEXP XGBoosterCreateInEmptyObj_R(SEXP dmats, SEXP R_handle) {
XGB_DLL SEXP XGBoosterSetParam_R(SEXP handle, SEXP name, SEXP val) { XGB_DLL SEXP XGBoosterSetParam_R(SEXP handle, SEXP name, SEXP val) {
R_API_BEGIN(); R_API_BEGIN();
SEXP name_ = PROTECT(Rf_asChar(name));
SEXP val_ = PROTECT(Rf_asChar(val));
CHECK_CALL(XGBoosterSetParam(R_ExternalPtrAddr(handle), CHECK_CALL(XGBoosterSetParam(R_ExternalPtrAddr(handle),
CHAR(asChar(name)), CHAR(name_),
CHAR(asChar(val)))); CHAR(val_)));
UNPROTECT(2);
R_API_END(); R_API_END();
return R_NilValue; return R_NilValue;
} }
@ -452,7 +514,7 @@ XGB_DLL SEXP XGBoosterUpdateOneIter_R(SEXP handle, SEXP iter, SEXP dtrain) {
XGB_DLL SEXP XGBoosterTrainOneIter_R(SEXP handle, SEXP dtrain, SEXP iter, SEXP grad, SEXP hess) { XGB_DLL SEXP XGBoosterTrainOneIter_R(SEXP handle, SEXP dtrain, SEXP iter, SEXP grad, SEXP hess) {
R_API_BEGIN(); R_API_BEGIN();
CHECK_EQ(length(grad), length(hess)) << "gradient and hess must have same length."; CHECK_EQ(Rf_xlength(grad), Rf_xlength(hess)) << "gradient and hess must have same length.";
SEXP gdim = getAttrib(grad, R_DimSymbol); SEXP gdim = getAttrib(grad, R_DimSymbol);
auto n_samples = static_cast<std::size_t>(INTEGER(gdim)[0]); auto n_samples = static_cast<std::size_t>(INTEGER(gdim)[0]);
auto n_targets = static_cast<std::size_t>(INTEGER(gdim)[1]); auto n_targets = static_cast<std::size_t>(INTEGER(gdim)[1]);
@ -463,11 +525,15 @@ XGB_DLL SEXP XGBoosterTrainOneIter_R(SEXP handle, SEXP dtrain, SEXP iter, SEXP g
double const *d_grad = REAL(grad); double const *d_grad = REAL(grad);
double const *d_hess = REAL(hess); double const *d_hess = REAL(hess);
auto ctx = xgboost::detail::BoosterCtx(R_ExternalPtrAddr(handle)); int res_code;
auto [s_grad, s_hess] = xgboost::detail::MakeGradientInterface( {
ctx, d_grad, d_hess, xgboost::linalg::kF, n_samples, n_targets); auto ctx = xgboost::detail::BoosterCtx(R_ExternalPtrAddr(handle));
CHECK_CALL(XGBoosterTrainOneIter(R_ExternalPtrAddr(handle), R_ExternalPtrAddr(dtrain), auto [s_grad, s_hess] = xgboost::detail::MakeGradientInterface(
asInteger(iter), s_grad.c_str(), s_hess.c_str())); ctx, d_grad, d_hess, xgboost::linalg::kF, n_samples, n_targets);
res_code = XGBoosterTrainOneIter(R_ExternalPtrAddr(handle), R_ExternalPtrAddr(dtrain),
asInteger(iter), s_grad.c_str(), s_hess.c_str());
}
CHECK_CALL(res_code);
R_API_END(); R_API_END();
return R_NilValue; return R_NilValue;
@ -476,24 +542,34 @@ XGB_DLL SEXP XGBoosterTrainOneIter_R(SEXP handle, SEXP dtrain, SEXP iter, SEXP g
XGB_DLL SEXP XGBoosterEvalOneIter_R(SEXP handle, SEXP iter, SEXP dmats, SEXP evnames) { XGB_DLL SEXP XGBoosterEvalOneIter_R(SEXP handle, SEXP iter, SEXP dmats, SEXP evnames) {
const char *ret; const char *ret;
R_API_BEGIN(); R_API_BEGIN();
CHECK_EQ(length(dmats), length(evnames)) CHECK_EQ(Rf_xlength(dmats), Rf_xlength(evnames))
<< "dmats and evnams must have same length"; << "dmats and evnams must have same length";
int len = length(dmats); R_xlen_t len = Rf_xlength(dmats);
std::vector<void*> vec_dmats; SEXP evnames_lst = PROTECT(Rf_allocVector(VECSXP, len));
std::vector<std::string> vec_names; for (R_xlen_t i = 0; i < len; i++) {
std::vector<const char*> vec_sptr; SET_VECTOR_ELT(evnames_lst, i, Rf_asChar(VECTOR_ELT(evnames, i)));
for (int i = 0; i < len; ++i) {
vec_dmats.push_back(R_ExternalPtrAddr(VECTOR_ELT(dmats, i)));
vec_names.emplace_back(CHAR(asChar(VECTOR_ELT(evnames, i))));
} }
for (int i = 0; i < len; ++i) {
vec_sptr.push_back(vec_names[i].c_str()); int res_code;
{
std::vector<void*> vec_dmats;
std::vector<std::string> vec_names;
std::vector<const char*> vec_sptr;
for (R_xlen_t i = 0; i < len; ++i) {
vec_dmats.push_back(R_ExternalPtrAddr(VECTOR_ELT(dmats, i)));
vec_names.emplace_back(CHAR(VECTOR_ELT(evnames_lst, i)));
}
for (R_xlen_t i = 0; i < len; ++i) {
vec_sptr.push_back(vec_names[i].c_str());
}
res_code = XGBoosterEvalOneIter(R_ExternalPtrAddr(handle),
asInteger(iter),
BeginPtr(vec_dmats),
BeginPtr(vec_sptr),
len, &ret);
} }
CHECK_CALL(XGBoosterEvalOneIter(R_ExternalPtrAddr(handle), CHECK_CALL(res_code);
asInteger(iter), UNPROTECT(1);
BeginPtr(vec_dmats),
BeginPtr(vec_sptr),
len, &ret));
R_API_END(); R_API_END();
return mkString(ret); return mkString(ret);
} }
@ -501,10 +577,11 @@ XGB_DLL SEXP XGBoosterEvalOneIter_R(SEXP handle, SEXP iter, SEXP dmats, SEXP evn
XGB_DLL SEXP XGBoosterPredictFromDMatrix_R(SEXP handle, SEXP dmat, SEXP json_config) { XGB_DLL SEXP XGBoosterPredictFromDMatrix_R(SEXP handle, SEXP dmat, SEXP json_config) {
SEXP r_out_shape; SEXP r_out_shape;
SEXP r_out_result; SEXP r_out_result;
SEXP r_out; SEXP r_out = PROTECT(allocVector(VECSXP, 2));
SEXP json_config_ = PROTECT(Rf_asChar(json_config));
R_API_BEGIN(); R_API_BEGIN();
char const *c_json_config = CHAR(asChar(json_config)); char const *c_json_config = CHAR(json_config_);
bst_ulong out_dim; bst_ulong out_dim;
bst_ulong const *out_shape; bst_ulong const *out_shape;
@ -515,23 +592,23 @@ XGB_DLL SEXP XGBoosterPredictFromDMatrix_R(SEXP handle, SEXP dmat, SEXP json_con
r_out_shape = PROTECT(allocVector(INTSXP, out_dim)); r_out_shape = PROTECT(allocVector(INTSXP, out_dim));
size_t len = 1; size_t len = 1;
int *r_out_shape_ = INTEGER(r_out_shape);
for (size_t i = 0; i < out_dim; ++i) { for (size_t i = 0; i < out_dim; ++i) {
INTEGER(r_out_shape)[i] = out_shape[i]; r_out_shape_[i] = out_shape[i];
len *= out_shape[i]; len *= out_shape[i];
} }
r_out_result = PROTECT(allocVector(REALSXP, len)); r_out_result = PROTECT(allocVector(REALSXP, len));
auto ctx = xgboost::detail::BoosterCtx(R_ExternalPtrAddr(handle)); auto ctx = xgboost::detail::BoosterCtx(R_ExternalPtrAddr(handle));
double *r_out_result_ = REAL(r_out_result);
xgboost::common::ParallelFor(len, ctx->Threads(), [&](xgboost::omp_ulong i) { xgboost::common::ParallelFor(len, ctx->Threads(), [&](xgboost::omp_ulong i) {
REAL(r_out_result)[i] = out_result[i]; r_out_result_[i] = out_result[i];
}); });
r_out = PROTECT(allocVector(VECSXP, 2));
SET_VECTOR_ELT(r_out, 0, r_out_shape); SET_VECTOR_ELT(r_out, 0, r_out_shape);
SET_VECTOR_ELT(r_out, 1, r_out_result); SET_VECTOR_ELT(r_out, 1, r_out_result);
R_API_END(); R_API_END();
UNPROTECT(3); UNPROTECT(4);
return r_out; return r_out;
} }
@ -554,7 +631,7 @@ XGB_DLL SEXP XGBoosterLoadModelFromRaw_R(SEXP handle, SEXP raw) {
R_API_BEGIN(); R_API_BEGIN();
CHECK_CALL(XGBoosterLoadModelFromBuffer(R_ExternalPtrAddr(handle), CHECK_CALL(XGBoosterLoadModelFromBuffer(R_ExternalPtrAddr(handle),
RAW(raw), RAW(raw),
length(raw))); Rf_xlength(raw)));
R_API_END(); R_API_END();
return R_NilValue; return R_NilValue;
} }
@ -612,45 +689,54 @@ XGB_DLL SEXP XGBoosterUnserializeFromBuffer_R(SEXP handle, SEXP raw) {
R_API_BEGIN(); R_API_BEGIN();
CHECK_CALL(XGBoosterUnserializeFromBuffer(R_ExternalPtrAddr(handle), CHECK_CALL(XGBoosterUnserializeFromBuffer(R_ExternalPtrAddr(handle),
RAW(raw), RAW(raw),
length(raw))); Rf_xlength(raw)));
R_API_END(); R_API_END();
return R_NilValue; return R_NilValue;
} }
XGB_DLL SEXP XGBoosterDumpModel_R(SEXP handle, SEXP fmap, SEXP with_stats, SEXP dump_format) { XGB_DLL SEXP XGBoosterDumpModel_R(SEXP handle, SEXP fmap, SEXP with_stats, SEXP dump_format) {
SEXP out; SEXP out;
SEXP continuation_token = PROTECT(R_MakeUnwindCont());
SEXP dump_format_ = PROTECT(Rf_asChar(dump_format));
SEXP fmap_ = PROTECT(Rf_asChar(fmap));
R_API_BEGIN(); R_API_BEGIN();
bst_ulong olen; bst_ulong olen;
const char **res; const char **res;
const char *fmt = CHAR(asChar(dump_format)); const char *fmt = CHAR(dump_format_);
CHECK_CALL(XGBoosterDumpModelEx(R_ExternalPtrAddr(handle), CHECK_CALL(XGBoosterDumpModelEx(R_ExternalPtrAddr(handle),
CHAR(asChar(fmap)), CHAR(fmap_),
asInteger(with_stats), asInteger(with_stats),
fmt, fmt,
&olen, &res)); &olen, &res));
out = PROTECT(allocVector(STRSXP, olen)); out = PROTECT(allocVector(STRSXP, olen));
if (!strcmp("json", fmt)) { try {
std::stringstream stream; if (!strcmp("json", fmt)) {
stream << "[\n"; std::stringstream stream;
for (size_t i = 0; i < olen; ++i) { stream << "[\n";
stream << res[i]; for (size_t i = 0; i < olen; ++i) {
if (i < olen - 1) { stream << res[i];
stream << ",\n"; if (i < olen - 1) {
} else { stream << ",\n";
stream << "\n"; } else {
stream << "\n";
}
}
stream << "]";
const std::string temp_str = stream.str();
SET_STRING_ELT(out, 0, SafeMkChar(temp_str.c_str(), continuation_token));
} else {
for (size_t i = 0; i < olen; ++i) {
std::stringstream stream;
stream << "booster[" << i <<"]\n" << res[i];
const std::string temp_str = stream.str();
SET_STRING_ELT(out, i, SafeMkChar(temp_str.c_str(), continuation_token));
} }
} }
stream << "]"; } catch (ErrorWithUnwind &e) {
SET_STRING_ELT(out, 0, mkChar(stream.str().c_str())); R_ContinueUnwind(continuation_token);
} else {
for (size_t i = 0; i < olen; ++i) {
std::stringstream stream;
stream << "booster[" << i <<"]\n" << res[i];
SET_STRING_ELT(out, i, mkChar(stream.str().c_str()));
}
} }
R_API_END(); R_API_END();
UNPROTECT(1); UNPROTECT(4);
return out; return out;
} }
@ -676,9 +762,19 @@ XGB_DLL SEXP XGBoosterGetAttr_R(SEXP handle, SEXP name) {
XGB_DLL SEXP XGBoosterSetAttr_R(SEXP handle, SEXP name, SEXP val) { XGB_DLL SEXP XGBoosterSetAttr_R(SEXP handle, SEXP name, SEXP val) {
R_API_BEGIN(); R_API_BEGIN();
const char *v = isNull(val) ? nullptr : CHAR(asChar(val)); const char *v = nullptr;
SEXP name_ = PROTECT(Rf_asChar(name));
SEXP val_;
int n_protected = 1;
if (!Rf_isNull(val)) {
val_ = PROTECT(Rf_asChar(val));
n_protected++;
v = CHAR(val_);
}
CHECK_CALL(XGBoosterSetAttr(R_ExternalPtrAddr(handle), CHECK_CALL(XGBoosterSetAttr(R_ExternalPtrAddr(handle),
CHAR(asChar(name)), v)); CHAR(name_), v));
UNPROTECT(n_protected);
R_API_END(); R_API_END();
return R_NilValue; return R_NilValue;
} }
@ -707,7 +803,7 @@ XGB_DLL SEXP XGBoosterFeatureScore_R(SEXP handle, SEXP json_config) {
SEXP out_features_sexp; SEXP out_features_sexp;
SEXP out_scores_sexp; SEXP out_scores_sexp;
SEXP out_shape_sexp; SEXP out_shape_sexp;
SEXP r_out; SEXP r_out = PROTECT(allocVector(VECSXP, 3));
R_API_BEGIN(); R_API_BEGIN();
char const *c_json_config = CHAR(asChar(json_config)); char const *c_json_config = CHAR(asChar(json_config));
@ -723,23 +819,24 @@ XGB_DLL SEXP XGBoosterFeatureScore_R(SEXP handle, SEXP json_config) {
&out_dim, &out_shape, &out_scores)); &out_dim, &out_shape, &out_scores));
out_shape_sexp = PROTECT(allocVector(INTSXP, out_dim)); out_shape_sexp = PROTECT(allocVector(INTSXP, out_dim));
size_t len = 1; size_t len = 1;
int *out_shape_sexp_ = INTEGER(out_shape_sexp);
for (size_t i = 0; i < out_dim; ++i) { for (size_t i = 0; i < out_dim; ++i) {
INTEGER(out_shape_sexp)[i] = out_shape[i]; out_shape_sexp_[i] = out_shape[i];
len *= out_shape[i]; len *= out_shape[i];
} }
out_scores_sexp = PROTECT(allocVector(REALSXP, len));
auto ctx = xgboost::detail::BoosterCtx(R_ExternalPtrAddr(handle));
xgboost::common::ParallelFor(len, ctx->Threads(), [&](xgboost::omp_ulong i) {
REAL(out_scores_sexp)[i] = out_scores[i];
});
out_features_sexp = PROTECT(allocVector(STRSXP, out_n_features)); out_features_sexp = PROTECT(allocVector(STRSXP, out_n_features));
for (size_t i = 0; i < out_n_features; ++i) { for (size_t i = 0; i < out_n_features; ++i) {
SET_STRING_ELT(out_features_sexp, i, mkChar(out_features[i])); SET_STRING_ELT(out_features_sexp, i, mkChar(out_features[i]));
} }
r_out = PROTECT(allocVector(VECSXP, 3)); out_scores_sexp = PROTECT(allocVector(REALSXP, len));
auto ctx = xgboost::detail::BoosterCtx(R_ExternalPtrAddr(handle));
double *out_scores_sexp_ = REAL(out_scores_sexp);
xgboost::common::ParallelFor(len, ctx->Threads(), [&](xgboost::omp_ulong i) {
out_scores_sexp_[i] = out_scores[i];
});
SET_VECTOR_ELT(r_out, 0, out_features_sexp); SET_VECTOR_ELT(r_out, 0, out_features_sexp);
SET_VECTOR_ELT(r_out, 1, out_shape_sexp); SET_VECTOR_ELT(r_out, 1, out_shape_sexp);
SET_VECTOR_ELT(r_out, 2, out_scores_sexp); SET_VECTOR_ELT(r_out, 2, out_scores_sexp);