[R] On-demand serialization + standardization of attributes (#9924)

---------

Co-authored-by: Jiaming Yuan <jm.yuan@outlook.com>
This commit is contained in:
david-cortes
2024-01-10 22:08:42 +01:00
committed by GitHub
parent 01c4711556
commit d3a8d284ab
64 changed files with 1773 additions and 1281 deletions

View File

@@ -15,9 +15,16 @@ Check these declarations against the C/Fortran source code.
*/
/* .Call calls */
extern void XGBInitializeAltrepClass_R(DllInfo *info);
extern SEXP XGDuplicate_R(SEXP);
extern SEXP XGPointerEqComparison_R(SEXP, SEXP);
extern SEXP XGBoosterTrainOneIter_R(SEXP, SEXP, SEXP, SEXP, SEXP);
extern SEXP XGBoosterCreate_R(SEXP);
extern SEXP XGBoosterCreateInEmptyObj_R(SEXP, SEXP);
extern SEXP XGBoosterCopyInfoFromDMatrix_R(SEXP, SEXP);
extern SEXP XGBoosterSetStrFeatureInfo_R(SEXP, SEXP, SEXP);
extern SEXP XGBoosterGetStrFeatureInfo_R(SEXP, SEXP);
extern SEXP XGBoosterBoostedRounds_R(SEXP);
extern SEXP XGBoosterGetNumFeature_R(SEXP);
extern SEXP XGBoosterDumpModel_R(SEXP, SEXP, SEXP, SEXP);
extern SEXP XGBoosterEvalOneIter_R(SEXP, SEXP, SEXP, SEXP);
extern SEXP XGBoosterGetAttrNames_R(SEXP);
@@ -57,9 +64,15 @@ extern SEXP XGBGetGlobalConfig_R(void);
extern SEXP XGBoosterFeatureScore_R(SEXP, SEXP);
static const R_CallMethodDef CallEntries[] = {
{"XGDuplicate_R", (DL_FUNC) &XGDuplicate_R, 1},
{"XGPointerEqComparison_R", (DL_FUNC) &XGPointerEqComparison_R, 2},
{"XGBoosterTrainOneIter_R", (DL_FUNC) &XGBoosterTrainOneIter_R, 5},
{"XGBoosterCreate_R", (DL_FUNC) &XGBoosterCreate_R, 1},
{"XGBoosterCreateInEmptyObj_R", (DL_FUNC) &XGBoosterCreateInEmptyObj_R, 2},
{"XGBoosterCopyInfoFromDMatrix_R", (DL_FUNC) &XGBoosterCopyInfoFromDMatrix_R, 2},
{"XGBoosterSetStrFeatureInfo_R",(DL_FUNC) &XGBoosterSetStrFeatureInfo_R,3}, // NOLINT
{"XGBoosterGetStrFeatureInfo_R",(DL_FUNC) &XGBoosterGetStrFeatureInfo_R,2}, // NOLINT
{"XGBoosterBoostedRounds_R", (DL_FUNC) &XGBoosterBoostedRounds_R, 1},
{"XGBoosterGetNumFeature_R", (DL_FUNC) &XGBoosterGetNumFeature_R, 1},
{"XGBoosterDumpModel_R", (DL_FUNC) &XGBoosterDumpModel_R, 4},
{"XGBoosterEvalOneIter_R", (DL_FUNC) &XGBoosterEvalOneIter_R, 4},
{"XGBoosterGetAttrNames_R", (DL_FUNC) &XGBoosterGetAttrNames_R, 1},
@@ -106,4 +119,5 @@ __declspec(dllexport)
void attribute_visible R_init_xgboost(DllInfo *dll) {
/* Package load hook: the three steps below run in this order when R loads the shared library. */
/* Register the .Call entry points listed in CallEntries (no .C/.Fortran/.External tables). */
R_registerRoutines(dll, NULL, CallEntries, NULL, NULL);
/* Restrict lookups to the registered routines only (no dynamic symbol search). */
R_useDynamicSymbols(dll, FALSE);
/* Set up the ALTREP class that wraps booster handles; it needs the DllInfo of this library. */
XGBInitializeAltrepClass_R(dll);
}

View File

@@ -260,16 +260,18 @@ char cpp_ex_msg[512];
using dmlc::BeginPtr;
XGB_DLL SEXP XGCheckNullPtr_R(SEXP handle) {
return ScalarLogical(R_ExternalPtrAddr(handle) == NULL);
return Rf_ScalarLogical(R_ExternalPtrAddr(handle) == nullptr);
}
XGB_DLL void _DMatrixFinalizer(SEXP ext) {
namespace {
/* Finalizer for DMatrix external pointers: releases the C-level DMatrix
   (when one is still attached) and then clears the R external pointer. */
void _DMatrixFinalizer(SEXP ext) {
  R_API_BEGIN();
  void *dmat_handle = R_ExternalPtrAddr(ext);
  /* Nothing to release if the pointer was already cleared. */
  if (dmat_handle == NULL) return;
  CHECK_CALL(XGDMatrixFree(dmat_handle));
  R_ClearExternalPtr(ext);
  R_API_END();
}
} /* namespace */
XGB_DLL SEXP XGBSetGlobalConfig_R(SEXP json_str) {
R_API_BEGIN();
@@ -527,8 +529,14 @@ XGB_DLL SEXP XGDMatrixSetStrFeatureInfo_R(SEXP handle, SEXP field, SEXP array) {
}
SEXP str_info_holder = PROTECT(Rf_allocVector(VECSXP, len));
for (size_t i = 0; i < len; ++i) {
SET_VECTOR_ELT(str_info_holder, i, Rf_asChar(VECTOR_ELT(array, i)));
if (TYPEOF(array) == STRSXP) {
for (size_t i = 0; i < len; ++i) {
SET_VECTOR_ELT(str_info_holder, i, STRING_ELT(array, i));
}
} else {
for (size_t i = 0; i < len; ++i) {
SET_VECTOR_ELT(str_info_holder, i, Rf_asChar(VECTOR_ELT(array, i)));
}
}
SEXP field_ = PROTECT(Rf_asChar(field));
@@ -614,6 +622,14 @@ XGB_DLL SEXP XGDMatrixNumCol_R(SEXP handle) {
return ScalarInteger(static_cast<int>(ncol));
}
/* Thin .Call wrapper exposing R's C-level 'duplicate' to R code. */
XGB_DLL SEXP XGDuplicate_R(SEXP obj) {
  SEXP duplicated = Rf_duplicate(obj);
  return duplicated;
}
/* Returns a logical scalar telling whether two R external pointers
   hold the same underlying address. */
XGB_DLL SEXP XGPointerEqComparison_R(SEXP obj1, SEXP obj2) {
  void *addr_first = R_ExternalPtrAddr(obj1);
  void *addr_second = R_ExternalPtrAddr(obj2);
  const bool same_address = (addr_first == addr_second);
  return Rf_ScalarLogical(same_address);
}
XGB_DLL SEXP XGDMatrixGetQuantileCut_R(SEXP handle) {
const char *out_names[] = {"indptr", "data", ""};
SEXP continuation_token = Rf_protect(R_MakeUnwindCont());
@@ -682,14 +698,134 @@ XGB_DLL SEXP XGDMatrixGetDataAsCSR_R(SEXP handle) {
}
// functions related to booster
void _BoosterFinalizer(SEXP ext) {
if (R_ExternalPtrAddr(ext) == NULL) return;
CHECK_CALL(XGBoosterFree(R_ExternalPtrAddr(ext)));
R_ClearExternalPtr(ext);
namespace {
/* Finalizer for booster external pointers: frees the C-level booster
   (when one is still attached) and then clears the R external pointer. */
void _BoosterFinalizer(SEXP R_ptr) {
  void *booster_handle = R_ExternalPtrAddr(R_ptr);
  /* Already cleared (or never set): nothing to free. */
  if (booster_handle == NULL) return;
  CHECK_CALL(XGBoosterFree(booster_handle));
  R_ClearExternalPtr(R_ptr);
}
/* Booster is represented as an altrep list with one element which
corresponds to an 'externalptr' holding the C object, forbidding
modification by not implementing setters, and adding custom serialization. */
R_altrep_class_t XGBAltrepPointerClass;
/* ALTREP 'Length' method: the booster list always has exactly one
   element (the external pointer), regardless of the object passed. */
R_xlen_t XGBAltrepPointerLength_R(SEXP R_altrepped_obj) {
  const R_xlen_t n_elements = 1;
  return n_elements;
}
/* ALTREP list 'Elt' method: whatever index is requested, the element
   returned is the object stored in the altrep's data1 slot. */
SEXP XGBAltrepPointerGetElt_R(SEXP R_altrepped_obj, R_xlen_t idx) {
  SEXP stored_ptr = R_altrep_data1(R_altrepped_obj);
  return stored_ptr;
}
/* Builds an altrep booster object whose data1 slot is an external pointer
   with a null address. The result carries names "ptr" and class
   "xgb.Booster". Only R allocations happen here; the C handle is attached
   separately. */
SEXP XGBMakeEmptyAltrep() {
  const int n_protected = 4;
  SEXP booster_class = Rf_protect(Rf_mkString("xgb.Booster"));
  SEXP names_attr = Rf_protect(Rf_mkString("ptr"));
  SEXP null_ext_ptr = Rf_protect(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
  SEXP altrep_obj = Rf_protect(
      R_new_altrep(XGBAltrepPointerClass, null_ext_ptr, R_NilValue));

  Rf_setAttrib(altrep_obj, R_NamesSymbol, names_attr);
  Rf_setAttrib(altrep_obj, R_ClassSymbol, booster_class);

  Rf_unprotect(n_protected);
  return altrep_obj;
}
/* Note: the idea for separating this function from the one above is to be
able to trigger all R allocations first before doing non-R allocations. */
/* Attaches a booster handle to the external pointer held in the altrep's
   data1 slot and registers _BoosterFinalizer on that pointer (with the
   'onexit' flag set to TRUE). */
void XGBAltrepSetPointer(SEXP R_altrepped_obj, BoosterHandle handle) {
  SEXP ext_ptr = R_altrep_data1(R_altrepped_obj);
  R_SetExternalPtrAddr(ext_ptr, handle);
  R_RegisterCFinalizerEx(ext_ptr, _BoosterFinalizer, TRUE);
}
/* ALTREP 'Serialized_state' method: serializes the booster behind the
   altrep object into a buffer and returns a copy of it as an R raw vector. */
SEXP XGBAltrepSerializer_R(SEXP R_altrepped_obj) {
  R_API_BEGIN();
  SEXP ext_ptr = R_altrep_data1(R_altrepped_obj);
  BoosterHandle booster = R_ExternalPtrAddr(ext_ptr);

  /* Filled in by the serialization call below. */
  char const *buffer;
  bst_ulong n_bytes;
  CHECK_CALL(XGBoosterSerializeToBuffer(booster, &n_bytes, &buffer));

  /* Copy the bytes into R-managed memory. */
  SEXP R_raw = Rf_protect(Rf_allocVector(RAWSXP, n_bytes));
  if (n_bytes > 0) {
    std::memcpy(RAW(R_raw), buffer, n_bytes);
  }
  Rf_unprotect(1);
  return R_raw;
  R_API_END();
  return R_NilValue; /* <- should not be reached */
}
/* ALTREP 'Unserialize' method: creates a fresh booster, loads it from the
   raw vector R_state, and returns it wrapped in a new altrep object.
   The first argument is not used. */
SEXP XGBAltrepDeserializer_R(SEXP unused, SEXP R_state) {
  /* R-side allocation happens first, before the C handle is created. */
  SEXP R_booster = Rf_protect(XGBMakeEmptyAltrep());
  R_API_BEGIN();
  BoosterHandle restored = nullptr;
  CHECK_CALL(XGBoosterCreate(nullptr, 0, &restored));

  const void *state_bytes = RAW(R_state);
  const R_xlen_t state_len = Rf_xlength(R_state);
  const int status = XGBoosterUnserializeFromBuffer(restored, state_bytes, state_len);
  /* On failure, release the handle before the error is raised by CHECK_CALL. */
  if (status != 0) {
    XGBoosterFree(restored);
  }
  CHECK_CALL(status);

  XGBAltrepSetPointer(R_booster, restored);
  R_API_END();
  Rf_unprotect(1);
  return R_booster;
}
// https://purrple.cat/blog/2018/10/14/altrep-and-cpp/
/* ALTREP 'Inspect' method: prints the address held by the external pointer
   in data1 and returns TRUE. The remaining arguments are not used. */
Rboolean XGBAltrepInspector_R(
    SEXP x, int pre, int deep, int pvec,
    void (*inspect_subtree)(SEXP, int, int, int)) {
  void *booster_addr = R_ExternalPtrAddr(R_altrep_data1(x));
  Rprintf("Altrepped external pointer [address:%p]\n", booster_addr);
  return TRUE;
}
/* ALTREP 'Duplicate' method.
   - Shallow copy: the new altrep object reuses the very same external
     pointer (data1) as the source object.
   - Deep copy: the source booster is serialized to a buffer and loaded
     into a newly created booster, which is attached to the new object. */
SEXP XGBAltrepDuplicate_R(SEXP R_altrepped_obj, Rboolean deep) {
  R_API_BEGIN();
  /* Both paths start from an empty altrep object. */
  SEXP R_copy = Rf_protect(XGBMakeEmptyAltrep());

  if (!deep) {
    R_set_altrep_data1(R_copy, R_altrep_data1(R_altrepped_obj));
    Rf_unprotect(1);
    return R_copy;
  }

  /* Deep copy: round-trip through the serialized representation. */
  char const *model_bytes;
  bst_ulong n_bytes;
  CHECK_CALL(XGBoosterSerializeToBuffer(
      R_ExternalPtrAddr(R_altrep_data1(R_altrepped_obj)),
      &n_bytes, &model_bytes));

  BoosterHandle cloned = nullptr;
  CHECK_CALL(XGBoosterCreate(nullptr, 0, &cloned));
  const int status = XGBoosterUnserializeFromBuffer(cloned, model_bytes, n_bytes);
  /* On failure, release the new handle before the error is raised by CHECK_CALL. */
  if (status != 0) {
    XGBoosterFree(cloned);
  }
  CHECK_CALL(status);

  XGBAltrepSetPointer(R_copy, cloned);
  Rf_unprotect(1);
  return R_copy;
  R_API_END();
  return R_NilValue; /* <- should not be reached */
}
} /* namespace */
/* Creates the altrep list class used for boosters and installs its methods.
   No setter methods are installed here. */
XGB_DLL void XGBInitializeAltrepClass_R(DllInfo *dll) {
  XGBAltrepPointerClass = R_make_altlist_class("XGBAltrepPointerClass", "xgboost", dll);

  /* Element access. */
  R_set_altlist_Elt_method(XGBAltrepPointerClass, XGBAltrepPointerGetElt_R);
  R_set_altrep_Length_method(XGBAltrepPointerClass, XGBAltrepPointerLength_R);

  /* Copying and (de)serialization. */
  R_set_altrep_Duplicate_method(XGBAltrepPointerClass, XGBAltrepDuplicate_R);
  R_set_altrep_Serialized_state_method(XGBAltrepPointerClass, XGBAltrepSerializer_R);
  R_set_altrep_Unserialize_method(XGBAltrepPointerClass, XGBAltrepDeserializer_R);

  /* Debug printing. */
  R_set_altrep_Inspect_method(XGBAltrepPointerClass, XGBAltrepInspector_R);
}
XGB_DLL SEXP XGBoosterCreate_R(SEXP dmats) {
SEXP ret = PROTECT(R_MakeExternalPtr(nullptr, R_NilValue, R_NilValue));
SEXP out = Rf_protect(XGBMakeEmptyAltrep());
R_API_BEGIN();
R_xlen_t len = Rf_xlength(dmats);
BoosterHandle handle;
@@ -703,33 +839,104 @@ XGB_DLL SEXP XGBoosterCreate_R(SEXP dmats) {
res_code = XGBoosterCreate(BeginPtr(dvec), dvec.size(), &handle);
}
CHECK_CALL(res_code);
R_SetExternalPtrAddr(ret, handle);
R_RegisterCFinalizerEx(ret, _BoosterFinalizer, TRUE);
XGBAltrepSetPointer(out, handle);
R_API_END();
UNPROTECT(1);
return ret;
Rf_unprotect(1);
return out;
}
XGB_DLL SEXP XGBoosterCreateInEmptyObj_R(SEXP dmats, SEXP R_handle) {
XGB_DLL SEXP XGBoosterCopyInfoFromDMatrix_R(SEXP booster, SEXP dmat) {
R_API_BEGIN();
R_xlen_t len = Rf_xlength(dmats);
BoosterHandle handle;
char const **feature_names;
bst_ulong len_feature_names = 0;
CHECK_CALL(XGDMatrixGetStrFeatureInfo(R_ExternalPtrAddr(dmat),
"feature_name",
&len_feature_names,
&feature_names));
if (len_feature_names) {
CHECK_CALL(XGBoosterSetStrFeatureInfo(R_ExternalPtrAddr(booster),
"feature_name",
feature_names,
len_feature_names));
}
char const **feature_types;
bst_ulong len_feature_types = 0;
CHECK_CALL(XGDMatrixGetStrFeatureInfo(R_ExternalPtrAddr(dmat),
"feature_type",
&len_feature_types,
&feature_types));
if (len_feature_types) {
CHECK_CALL(XGBoosterSetStrFeatureInfo(R_ExternalPtrAddr(booster),
"feature_type",
feature_types,
len_feature_types));
}
R_API_END();
return R_NilValue;
}
XGB_DLL SEXP XGBoosterSetStrFeatureInfo_R(SEXP handle, SEXP field, SEXP features) {
R_API_BEGIN();
SEXP field_char = Rf_protect(Rf_asChar(field));
bst_ulong len_features = Rf_xlength(features);
int res_code;
{
std::vector<void*> dvec(len);
for (R_xlen_t i = 0; i < len; ++i) {
dvec[i] = R_ExternalPtrAddr(VECTOR_ELT(dmats, i));
std::vector<const char*> str_arr(len_features);
for (bst_ulong idx = 0; idx < len_features; idx++) {
str_arr[idx] = CHAR(STRING_ELT(features, idx));
}
res_code = XGBoosterCreate(BeginPtr(dvec), dvec.size(), &handle);
res_code = XGBoosterSetStrFeatureInfo(R_ExternalPtrAddr(handle),
CHAR(field_char),
str_arr.data(),
len_features);
}
CHECK_CALL(res_code);
R_SetExternalPtrAddr(R_handle, handle);
R_RegisterCFinalizerEx(R_handle, _BoosterFinalizer, TRUE);
Rf_unprotect(1);
R_API_END();
return R_NilValue;
}
/* Fetches the string feature info stored in a booster under 'field' and
   returns it as an R character vector (one element per returned string). */
XGB_DLL SEXP XGBoosterGetStrFeatureInfo_R(SEXP handle, SEXP field) {
  R_API_BEGIN();
  /* Filled in by the C API call below. */
  bst_ulong n_values;
  const char **values;
  SEXP field_name = Rf_protect(Rf_asChar(field));
  CHECK_CALL(XGBoosterGetStrFeatureInfo(R_ExternalPtrAddr(handle),
                                        CHAR(field_name), &n_values, &values));

  /* Copy each C string into a freshly allocated R character vector. */
  SEXP R_values = Rf_protect(Rf_allocVector(STRSXP, n_values));
  for (bst_ulong pos = 0; pos < n_values; ++pos) {
    SET_STRING_ELT(R_values, pos, Rf_mkChar(values[pos]));
  }
  Rf_unprotect(2);
  return R_values;
  R_API_END();
  return R_NilValue; /* <- should not be reached */
}
/* Returns the booster's number of boosted rounds as a length-1 R integer
   vector; the C API writes directly into the vector's storage. */
XGB_DLL SEXP XGBoosterBoostedRounds_R(SEXP handle) {
  SEXP R_rounds = Rf_protect(Rf_allocVector(INTSXP, 1));
  R_API_BEGIN();
  int *rounds_slot = INTEGER(R_rounds);
  CHECK_CALL(XGBoosterBoostedRounds(R_ExternalPtrAddr(handle), rounds_slot));
  R_API_END();
  Rf_unprotect(1);
  return R_rounds;
}
/* Note: R's integer class is 32-bit-and-signed only, while xgboost
supports more, so it returns it as a floating point instead */
/* Returns the booster's number of features as a length-1 R numeric
   (double) vector, converted from the C API's bst_ulong result. */
XGB_DLL SEXP XGBoosterGetNumFeature_R(SEXP handle) {
  SEXP R_num_features = Rf_protect(Rf_allocVector(REALSXP, 1));
  R_API_BEGIN();
  bst_ulong n_features;
  CHECK_CALL(XGBoosterGetNumFeature(R_ExternalPtrAddr(handle), &n_features));
  double *result_slot = REAL(R_num_features);
  result_slot[0] = static_cast<double>(n_features);
  R_API_END();
  Rf_unprotect(1);
  return R_num_features;
}
XGB_DLL SEXP XGBoosterSetParam_R(SEXP handle, SEXP name, SEXP val) {
R_API_BEGIN();
SEXP name_ = PROTECT(Rf_asChar(name));
@@ -745,8 +952,8 @@ XGB_DLL SEXP XGBoosterSetParam_R(SEXP handle, SEXP name, SEXP val) {
XGB_DLL SEXP XGBoosterUpdateOneIter_R(SEXP handle, SEXP iter, SEXP dtrain) {
R_API_BEGIN();
CHECK_CALL(XGBoosterUpdateOneIter(R_ExternalPtrAddr(handle),
asInteger(iter),
R_ExternalPtrAddr(dtrain)));
Rf_asInteger(iter),
R_ExternalPtrAddr(dtrain)));
R_API_END();
return R_NilValue;
}

View File

@@ -8,7 +8,9 @@
#define XGBOOST_R_H_ // NOLINT(*)
#include <R.h>
#include <Rinternals.h>
#include <R_ext/Altrep.h>
#include <R_ext/Random.h>
#include <Rmath.h>
@@ -143,6 +145,25 @@ XGB_DLL SEXP XGDMatrixNumRow_R(SEXP handle);
*/
XGB_DLL SEXP XGDMatrixNumCol_R(SEXP handle);
/*!
* \brief Call R C-level function 'duplicate'
* \param obj Object to duplicate
*/
XGB_DLL SEXP XGDuplicate_R(SEXP obj);
/*!
* \brief Equality comparison for two pointers
* \param obj1 R 'externalptr'
* \param obj2 R 'externalptr'
*/
XGB_DLL SEXP XGPointerEqComparison_R(SEXP obj1, SEXP obj2);
/*!
* \brief Register the Altrep class used for the booster
* \param dll DLL info as provided by R_init
*/
XGB_DLL void XGBInitializeAltrepClass_R(DllInfo *dll);
/*!
* \brief return the quantile cuts used for the histogram method
* \param handle an instance of data matrix
@@ -174,13 +195,37 @@ XGB_DLL SEXP XGDMatrixGetDataAsCSR_R(SEXP handle);
*/
XGB_DLL SEXP XGBoosterCreate_R(SEXP dmats);
/*!
* \brief copy information about features from a DMatrix into a Booster
* \param booster R 'externalptr' pointing to a booster object
* \param dmat R 'externalptr' pointing to a DMatrix object
*/
XGB_DLL SEXP XGBoosterCopyInfoFromDMatrix_R(SEXP booster, SEXP dmat);
/*!
* \brief create xgboost learner, saving the pointer into an existing R object
* \param dmats a list of dmatrix handles that will be cached
* \param R_handle a clean R external pointer (not holding any object)
* \brief handle R 'externalptr' holding the booster object
* \param field field name
* \param features features to set for the field
*/
XGB_DLL SEXP XGBoosterCreateInEmptyObj_R(SEXP dmats, SEXP R_handle);
XGB_DLL SEXP XGBoosterSetStrFeatureInfo_R(SEXP handle, SEXP field, SEXP features);
/*!
* \brief Get the string feature info stored in a booster under a given field
* \param handle R 'externalptr' holding the booster object
* \param field field name
* \return R character vector with the strings stored under the field
*/
XGB_DLL SEXP XGBoosterGetStrFeatureInfo_R(SEXP handle, SEXP field);
/*!
* \brief Get the number of boosted rounds from a model
* \param handle R 'externalptr' holding the booster object
*/
XGB_DLL SEXP XGBoosterBoostedRounds_R(SEXP handle);
/*!
* \brief Get the number of features to which the model was fitted
* \param handle R 'externalptr' holding the booster object
*/
XGB_DLL SEXP XGBoosterGetNumFeature_R(SEXP handle);
/*!
* \brief set parameters