more usage of array interface, fix potential memory leaks of std::string (#9824)
This commit is contained in:
parent
37da66f865
commit
95af5c074b
@ -59,6 +59,32 @@ namespace {
|
||||
return "";
|
||||
}
|
||||
|
||||
[[nodiscard]] std::string MakeArrayInterfaceFromRVector(SEXP R_vec) {
|
||||
const size_t vec_len = Rf_xlength(R_vec);
|
||||
|
||||
// Lambda for type dispatch.
|
||||
auto make_vec = [=](auto const *ptr) {
|
||||
using namespace xgboost; // NOLINT
|
||||
auto v = linalg::MakeVec(ptr, vec_len);
|
||||
return linalg::ArrayInterfaceStr(v);
|
||||
};
|
||||
|
||||
const SEXPTYPE arr_type = TYPEOF(R_vec);
|
||||
switch (arr_type) {
|
||||
case REALSXP:
|
||||
return make_vec(REAL(R_vec));
|
||||
case INTSXP:
|
||||
return make_vec(INTEGER(R_vec));
|
||||
case LGLSXP:
|
||||
return make_vec(LOGICAL(R_vec));
|
||||
default:
|
||||
LOG(FATAL) << "Array or matrix has unsupported type.";
|
||||
}
|
||||
|
||||
LOG(FATAL) << "Not reachable";
|
||||
return "";
|
||||
}
|
||||
|
||||
[[nodiscard]] std::string MakeJsonConfigForArray(SEXP missing, SEXP n_threads, SEXPTYPE arr_type) {
|
||||
using namespace ::xgboost; // NOLINT
|
||||
Json jconfig{Object{}};
|
||||
@ -159,12 +185,15 @@ XGB_DLL SEXP XGDMatrixCreateFromMat_R(SEXP mat, SEXP missing, SEXP n_threads) {
|
||||
SEXP ret = PROTECT(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
|
||||
R_API_BEGIN();
|
||||
|
||||
auto array_str = MakeArrayInterfaceFromRMat(mat);
|
||||
auto config_str = MakeJsonConfigForArray(missing, n_threads, TYPEOF(mat));
|
||||
|
||||
DMatrixHandle handle;
|
||||
CHECK_CALL(XGDMatrixCreateFromDense(array_str.c_str(), config_str.c_str(), &handle));
|
||||
int res_code;
|
||||
{
|
||||
auto array_str = MakeArrayInterfaceFromRMat(mat);
|
||||
auto config_str = MakeJsonConfigForArray(missing, n_threads, TYPEOF(mat));
|
||||
|
||||
res_code = XGDMatrixCreateFromDense(array_str.c_str(), config_str.c_str(), &handle);
|
||||
}
|
||||
CHECK_CALL(res_code);
|
||||
R_SetExternalPtrAddr(ret, handle);
|
||||
R_RegisterCFinalizerEx(ret, _DMatrixFinalizer, TRUE);
|
||||
R_API_END();
|
||||
@ -279,23 +308,15 @@ XGB_DLL SEXP XGDMatrixSaveBinary_R(SEXP handle, SEXP fname, SEXP silent) {
|
||||
|
||||
XGB_DLL SEXP XGDMatrixSetInfo_R(SEXP handle, SEXP field, SEXP array) {
|
||||
R_API_BEGIN();
|
||||
int len = length(array);
|
||||
const char *name = CHAR(asChar(field));
|
||||
auto ctx = DMatrixCtx(R_ExternalPtrAddr(handle));
|
||||
if (!strcmp("group", name)) {
|
||||
std::vector<unsigned> vec(len);
|
||||
xgboost::common::ParallelFor(len, ctx->Threads(), [&](xgboost::omp_ulong i) {
|
||||
vec[i] = static_cast<unsigned>(INTEGER(array)[i]);
|
||||
});
|
||||
CHECK_CALL(
|
||||
XGDMatrixSetUIntInfo(R_ExternalPtrAddr(handle), CHAR(asChar(field)), BeginPtr(vec), len));
|
||||
} else {
|
||||
std::vector<float> vec(len);
|
||||
xgboost::common::ParallelFor(len, ctx->Threads(),
|
||||
[&](xgboost::omp_ulong i) { vec[i] = REAL(array)[i]; });
|
||||
CHECK_CALL(
|
||||
XGDMatrixSetFloatInfo(R_ExternalPtrAddr(handle), CHAR(asChar(field)), BeginPtr(vec), len));
|
||||
SEXP field_ = PROTECT(Rf_asChar(field));
|
||||
int res_code;
|
||||
{
|
||||
const std::string array_str = MakeArrayInterfaceFromRVector(array);
|
||||
res_code = XGDMatrixSetInfoFromInterface(
|
||||
R_ExternalPtrAddr(handle), CHAR(field_), array_str.c_str());
|
||||
}
|
||||
CHECK_CALL(res_code);
|
||||
UNPROTECT(1);
|
||||
R_API_END();
|
||||
return R_NilValue;
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user