more usage of array interface, fix potential memory leaks of std::string (#9824)
This commit is contained in:
parent
37da66f865
commit
95af5c074b
@ -59,6 +59,32 @@ namespace {
|
|||||||
return "";
|
return "";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
[[nodiscard]] std::string MakeArrayInterfaceFromRVector(SEXP R_vec) {
|
||||||
|
const size_t vec_len = Rf_xlength(R_vec);
|
||||||
|
|
||||||
|
// Lambda for type dispatch.
|
||||||
|
auto make_vec = [=](auto const *ptr) {
|
||||||
|
using namespace xgboost; // NOLINT
|
||||||
|
auto v = linalg::MakeVec(ptr, vec_len);
|
||||||
|
return linalg::ArrayInterfaceStr(v);
|
||||||
|
};
|
||||||
|
|
||||||
|
const SEXPTYPE arr_type = TYPEOF(R_vec);
|
||||||
|
switch (arr_type) {
|
||||||
|
case REALSXP:
|
||||||
|
return make_vec(REAL(R_vec));
|
||||||
|
case INTSXP:
|
||||||
|
return make_vec(INTEGER(R_vec));
|
||||||
|
case LGLSXP:
|
||||||
|
return make_vec(LOGICAL(R_vec));
|
||||||
|
default:
|
||||||
|
LOG(FATAL) << "Array or matrix has unsupported type.";
|
||||||
|
}
|
||||||
|
|
||||||
|
LOG(FATAL) << "Not reachable";
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
|
||||||
[[nodiscard]] std::string MakeJsonConfigForArray(SEXP missing, SEXP n_threads, SEXPTYPE arr_type) {
|
[[nodiscard]] std::string MakeJsonConfigForArray(SEXP missing, SEXP n_threads, SEXPTYPE arr_type) {
|
||||||
using namespace ::xgboost; // NOLINT
|
using namespace ::xgboost; // NOLINT
|
||||||
Json jconfig{Object{}};
|
Json jconfig{Object{}};
|
||||||
@ -159,12 +185,15 @@ XGB_DLL SEXP XGDMatrixCreateFromMat_R(SEXP mat, SEXP missing, SEXP n_threads) {
|
|||||||
SEXP ret = PROTECT(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
|
SEXP ret = PROTECT(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
|
||||||
R_API_BEGIN();
|
R_API_BEGIN();
|
||||||
|
|
||||||
auto array_str = MakeArrayInterfaceFromRMat(mat);
|
|
||||||
auto config_str = MakeJsonConfigForArray(missing, n_threads, TYPEOF(mat));
|
|
||||||
|
|
||||||
DMatrixHandle handle;
|
DMatrixHandle handle;
|
||||||
CHECK_CALL(XGDMatrixCreateFromDense(array_str.c_str(), config_str.c_str(), &handle));
|
int res_code;
|
||||||
|
{
|
||||||
|
auto array_str = MakeArrayInterfaceFromRMat(mat);
|
||||||
|
auto config_str = MakeJsonConfigForArray(missing, n_threads, TYPEOF(mat));
|
||||||
|
|
||||||
|
res_code = XGDMatrixCreateFromDense(array_str.c_str(), config_str.c_str(), &handle);
|
||||||
|
}
|
||||||
|
CHECK_CALL(res_code);
|
||||||
R_SetExternalPtrAddr(ret, handle);
|
R_SetExternalPtrAddr(ret, handle);
|
||||||
R_RegisterCFinalizerEx(ret, _DMatrixFinalizer, TRUE);
|
R_RegisterCFinalizerEx(ret, _DMatrixFinalizer, TRUE);
|
||||||
R_API_END();
|
R_API_END();
|
||||||
@ -279,23 +308,15 @@ XGB_DLL SEXP XGDMatrixSaveBinary_R(SEXP handle, SEXP fname, SEXP silent) {
|
|||||||
|
|
||||||
XGB_DLL SEXP XGDMatrixSetInfo_R(SEXP handle, SEXP field, SEXP array) {
|
XGB_DLL SEXP XGDMatrixSetInfo_R(SEXP handle, SEXP field, SEXP array) {
|
||||||
R_API_BEGIN();
|
R_API_BEGIN();
|
||||||
int len = length(array);
|
SEXP field_ = PROTECT(Rf_asChar(field));
|
||||||
const char *name = CHAR(asChar(field));
|
int res_code;
|
||||||
auto ctx = DMatrixCtx(R_ExternalPtrAddr(handle));
|
{
|
||||||
if (!strcmp("group", name)) {
|
const std::string array_str = MakeArrayInterfaceFromRVector(array);
|
||||||
std::vector<unsigned> vec(len);
|
res_code = XGDMatrixSetInfoFromInterface(
|
||||||
xgboost::common::ParallelFor(len, ctx->Threads(), [&](xgboost::omp_ulong i) {
|
R_ExternalPtrAddr(handle), CHAR(field_), array_str.c_str());
|
||||||
vec[i] = static_cast<unsigned>(INTEGER(array)[i]);
|
|
||||||
});
|
|
||||||
CHECK_CALL(
|
|
||||||
XGDMatrixSetUIntInfo(R_ExternalPtrAddr(handle), CHAR(asChar(field)), BeginPtr(vec), len));
|
|
||||||
} else {
|
|
||||||
std::vector<float> vec(len);
|
|
||||||
xgboost::common::ParallelFor(len, ctx->Threads(),
|
|
||||||
[&](xgboost::omp_ulong i) { vec[i] = REAL(array)[i]; });
|
|
||||||
CHECK_CALL(
|
|
||||||
XGDMatrixSetFloatInfo(R_ExternalPtrAddr(handle), CHAR(asChar(field)), BeginPtr(vec), len));
|
|
||||||
}
|
}
|
||||||
|
CHECK_CALL(res_code);
|
||||||
|
UNPROTECT(1);
|
||||||
R_API_END();
|
R_API_END();
|
||||||
return R_NilValue;
|
return R_NilValue;
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user