[R] Implement feature info for DMatrix. (#8048)
This commit is contained in:
@@ -42,10 +42,12 @@ extern SEXP XGDMatrixCreateFromCSR_R(SEXP, SEXP, SEXP, SEXP, SEXP);
|
||||
extern SEXP XGDMatrixCreateFromFile_R(SEXP, SEXP);
|
||||
extern SEXP XGDMatrixCreateFromMat_R(SEXP, SEXP, SEXP);
|
||||
extern SEXP XGDMatrixGetInfo_R(SEXP, SEXP);
|
||||
extern SEXP XGDMatrixGetStrFeatureInfo_R(SEXP, SEXP);
|
||||
extern SEXP XGDMatrixNumCol_R(SEXP);
|
||||
extern SEXP XGDMatrixNumRow_R(SEXP);
|
||||
extern SEXP XGDMatrixSaveBinary_R(SEXP, SEXP, SEXP);
|
||||
extern SEXP XGDMatrixSetInfo_R(SEXP, SEXP, SEXP);
|
||||
extern SEXP XGDMatrixSetStrFeatureInfo_R(SEXP, SEXP, SEXP);
|
||||
extern SEXP XGDMatrixSliceDMatrix_R(SEXP, SEXP);
|
||||
extern SEXP XGBSetGlobalConfig_R(SEXP);
|
||||
extern SEXP XGBGetGlobalConfig_R();
|
||||
@@ -78,10 +80,12 @@ static const R_CallMethodDef CallEntries[] = {
|
||||
{"XGDMatrixCreateFromFile_R", (DL_FUNC) &XGDMatrixCreateFromFile_R, 2},
|
||||
{"XGDMatrixCreateFromMat_R", (DL_FUNC) &XGDMatrixCreateFromMat_R, 3},
|
||||
{"XGDMatrixGetInfo_R", (DL_FUNC) &XGDMatrixGetInfo_R, 2},
|
||||
{"XGDMatrixGetStrFeatureInfo_R", (DL_FUNC) &XGDMatrixGetStrFeatureInfo_R, 2},
|
||||
{"XGDMatrixNumCol_R", (DL_FUNC) &XGDMatrixNumCol_R, 1},
|
||||
{"XGDMatrixNumRow_R", (DL_FUNC) &XGDMatrixNumRow_R, 1},
|
||||
{"XGDMatrixSaveBinary_R", (DL_FUNC) &XGDMatrixSaveBinary_R, 3},
|
||||
{"XGDMatrixSetInfo_R", (DL_FUNC) &XGDMatrixSetInfo_R, 3},
|
||||
{"XGDMatrixSetStrFeatureInfo_R", (DL_FUNC) &XGDMatrixSetStrFeatureInfo_R, 3},
|
||||
{"XGDMatrixSliceDMatrix_R", (DL_FUNC) &XGDMatrixSliceDMatrix_R, 2},
|
||||
{"XGBSetGlobalConfig_R", (DL_FUNC) &XGBSetGlobalConfig_R, 1},
|
||||
{"XGBGetGlobalConfig_R", (DL_FUNC) &XGBGetGlobalConfig_R, 0},
|
||||
|
||||
@@ -249,15 +249,53 @@ XGB_DLL SEXP XGDMatrixSetInfo_R(SEXP handle, SEXP field, SEXP array) {
|
||||
return R_NilValue;
|
||||
}
|
||||
|
||||
XGB_DLL SEXP XGDMatrixSetStrFeatureInfo_R(SEXP handle, SEXP field, SEXP array) {
|
||||
R_API_BEGIN();
|
||||
size_t len{0};
|
||||
if (!isNull(array)) {
|
||||
len = length(array);
|
||||
}
|
||||
|
||||
const char *name = CHAR(asChar(field));
|
||||
std::vector<std::string> str_info;
|
||||
for (size_t i = 0; i < len; ++i) {
|
||||
str_info.emplace_back(CHAR(asChar(VECTOR_ELT(array, i))));
|
||||
}
|
||||
std::vector<char const*> vec(len);
|
||||
std::transform(str_info.cbegin(), str_info.cend(), vec.begin(),
|
||||
[](auto const &str) { return str.c_str(); });
|
||||
CHECK_CALL(XGDMatrixSetStrFeatureInfo(R_ExternalPtrAddr(handle), name, vec.data(), len));
|
||||
R_API_END();
|
||||
return R_NilValue;
|
||||
}
|
||||
|
||||
XGB_DLL SEXP XGDMatrixGetStrFeatureInfo_R(SEXP handle, SEXP field) {
|
||||
SEXP ret;
|
||||
R_API_BEGIN();
|
||||
char const **out_features{nullptr};
|
||||
bst_ulong len{0};
|
||||
const char *name = CHAR(asChar(field));
|
||||
XGDMatrixGetStrFeatureInfo(R_ExternalPtrAddr(handle), name, &len, &out_features);
|
||||
|
||||
if (len > 0) {
|
||||
ret = PROTECT(allocVector(STRSXP, len));
|
||||
for (size_t i = 0; i < len; ++i) {
|
||||
SET_STRING_ELT(ret, i, mkChar(out_features[i]));
|
||||
}
|
||||
} else {
|
||||
ret = PROTECT(R_NilValue);
|
||||
}
|
||||
R_API_END();
|
||||
UNPROTECT(1);
|
||||
return ret;
|
||||
}
|
||||
|
||||
XGB_DLL SEXP XGDMatrixGetInfo_R(SEXP handle, SEXP field) {
|
||||
SEXP ret;
|
||||
R_API_BEGIN();
|
||||
bst_ulong olen;
|
||||
const float *res;
|
||||
CHECK_CALL(XGDMatrixGetFloatInfo(R_ExternalPtrAddr(handle),
|
||||
CHAR(asChar(field)),
|
||||
&olen,
|
||||
&res));
|
||||
CHECK_CALL(XGDMatrixGetFloatInfo(R_ExternalPtrAddr(handle), CHAR(asChar(field)), &olen, &res));
|
||||
ret = PROTECT(allocVector(REALSXP, olen));
|
||||
for (size_t i = 0; i < olen; ++i) {
|
||||
REAL(ret)[i] = res[i];
|
||||
|
||||
Reference in New Issue
Block a user