[R] Use inplace predict (#9829)

---------

Co-authored-by: Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
david-cortes
2024-02-23 19:03:54 +01:00
committed by GitHub
parent 729fd97196
commit f7005d32c1
7 changed files with 450 additions and 46 deletions

View File

@@ -37,6 +37,9 @@ extern SEXP XGBoosterLoadJsonConfig_R(SEXP handle, SEXP value);
extern SEXP XGBoosterSerializeToBuffer_R(SEXP handle);
extern SEXP XGBoosterUnserializeFromBuffer_R(SEXP handle, SEXP raw);
extern SEXP XGBoosterPredictFromDMatrix_R(SEXP, SEXP, SEXP);
extern SEXP XGBoosterPredictFromDense_R(SEXP, SEXP, SEXP, SEXP, SEXP);
extern SEXP XGBoosterPredictFromCSR_R(SEXP, SEXP, SEXP, SEXP, SEXP);
extern SEXP XGBoosterPredictFromColumnar_R(SEXP, SEXP, SEXP, SEXP, SEXP);
extern SEXP XGBoosterSaveModel_R(SEXP, SEXP);
extern SEXP XGBoosterSetAttr_R(SEXP, SEXP, SEXP);
extern SEXP XGBoosterSetParam_R(SEXP, SEXP, SEXP);
@@ -96,6 +99,9 @@ static const R_CallMethodDef CallEntries[] = {
{"XGBoosterSerializeToBuffer_R", (DL_FUNC) &XGBoosterSerializeToBuffer_R, 1},
{"XGBoosterUnserializeFromBuffer_R", (DL_FUNC) &XGBoosterUnserializeFromBuffer_R, 2},
{"XGBoosterPredictFromDMatrix_R", (DL_FUNC) &XGBoosterPredictFromDMatrix_R, 3},
{"XGBoosterPredictFromDense_R", (DL_FUNC) &XGBoosterPredictFromDense_R, 5},
{"XGBoosterPredictFromCSR_R", (DL_FUNC) &XGBoosterPredictFromCSR_R, 5},
{"XGBoosterPredictFromColumnar_R", (DL_FUNC) &XGBoosterPredictFromColumnar_R, 5},
{"XGBoosterSaveModel_R", (DL_FUNC) &XGBoosterSaveModel_R, 2},
{"XGBoosterSetAttr_R", (DL_FUNC) &XGBoosterSetAttr_R, 3},
{"XGBoosterSetParam_R", (DL_FUNC) &XGBoosterSetParam_R, 3},

View File

@@ -13,6 +13,7 @@
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <memory>
#include <limits>
#include <sstream>
#include <string>
@@ -207,25 +208,24 @@ SEXP SafeAllocInteger(size_t size, SEXP continuation_token) {
return xgboost::Json::Dump(jinterface);
}
[[nodiscard]] std::string MakeJsonConfigForArray(SEXP missing, SEXP n_threads, SEXPTYPE arr_type) {
using namespace ::xgboost; // NOLINT
Json jconfig{Object{}};
const SEXPTYPE missing_type = TYPEOF(missing);
if (Rf_isNull(missing) || (missing_type == REALSXP && ISNAN(Rf_asReal(missing))) ||
(missing_type == LGLSXP && Rf_asLogical(missing) == R_NaInt) ||
(missing_type == INTSXP && Rf_asInteger(missing) == R_NaInt)) {
/*!
 * \brief Store the "missing" entry in a JSON configuration object.
 *
 * When `missing` is R NULL or converts to NaN through Rf_asReal, no explicit
 * value was given and a default depending on `arr_type` is written: quiet NaN
 * for REALSXP, R_NaInt for any other type. Otherwise the value of
 * Rf_asReal(missing) is written.
 *
 * \param jconfig JSON object that receives the "missing" key
 * \param missing R value supplied as the missing-value marker, may be NULL
 * \param arr_type R type of the data array the configuration is built for
 */
void AddMissingToJson(xgboost::Json *jconfig, SEXP missing, SEXPTYPE arr_type) {
if (Rf_isNull(missing) || ISNAN(Rf_asReal(missing))) {
// missing is not specified
if (arr_type == REALSXP) {
jconfig["missing"] = std::numeric_limits<double>::quiet_NaN();
(*jconfig)["missing"] = std::numeric_limits<double>::quiet_NaN();
} else {
jconfig["missing"] = R_NaInt;
(*jconfig)["missing"] = R_NaInt;
}
} else {
// missing specified
jconfig["missing"] = Rf_asReal(missing);
(*jconfig)["missing"] = Rf_asReal(missing);
}
}
/*!
 * \brief Build the JSON configuration string used when consuming an R array.
 *
 * The resulting object carries a "missing" entry (filled in by
 * AddMissingToJson) and an "nthread" entry taken from `n_threads`.
 *
 * \param missing R value used as the missing-value marker, may be NULL
 * \param n_threads R value converted with Rf_asInteger for "nthread"
 * \param arr_type R type of the array, forwarded to AddMissingToJson
 * \return the serialized JSON configuration
 */
[[nodiscard]] std::string MakeJsonConfigForArray(SEXP missing, SEXP n_threads, SEXPTYPE arr_type) {
  xgboost::Json config{xgboost::Object{}};
  AddMissingToJson(&config, missing, arr_type);
  config["nthread"] = Rf_asInteger(n_threads);
  return xgboost::Json::Dump(config);
}
@@ -411,7 +411,7 @@ XGB_DLL SEXP XGDMatrixCreateFromDF_R(SEXP df, SEXP missing, SEXP n_threads) {
DMatrixHandle handle;
std::int32_t rc{0};
{
std::string sinterface = MakeArrayInterfaceFromRDataFrame(df);
const std::string sinterface = MakeArrayInterfaceFromRDataFrame(df);
xgboost::Json jconfig{xgboost::Object{}};
jconfig["missing"] = asReal(missing);
jconfig["nthread"] = asInteger(n_threads);
@@ -463,7 +463,7 @@ XGB_DLL SEXP XGDMatrixCreateFromCSC_R(SEXP indptr, SEXP indices, SEXP data, SEXP
Json jconfig{Object{}};
// Construct configuration
jconfig["nthread"] = Integer{threads};
jconfig["missing"] = xgboost::Number{asReal(missing)};
AddMissingToJson(&jconfig, missing, TYPEOF(data));
std::string config;
Json::Dump(jconfig, &config);
res_code = XGDMatrixCreateFromCSC(sindptr.c_str(), sindices.c_str(), sdata.c_str(), nrow,
@@ -498,7 +498,7 @@ XGB_DLL SEXP XGDMatrixCreateFromCSR_R(SEXP indptr, SEXP indices, SEXP data, SEXP
Json jconfig{Object{}};
// Construct configuration
jconfig["nthread"] = Integer{threads};
jconfig["missing"] = xgboost::Number{asReal(missing)};
AddMissingToJson(&jconfig, missing, TYPEOF(data));
std::string config;
Json::Dump(jconfig, &config);
res_code = XGDMatrixCreateFromCSR(sindptr.c_str(), sindices.c_str(), sdata.c_str(), ncol,
@@ -1247,7 +1247,60 @@ XGB_DLL SEXP XGBoosterEvalOneIter_R(SEXP handle, SEXP iter, SEXP dmats, SEXP evn
return mkString(ret);
}
XGB_DLL SEXP XGBoosterPredictFromDMatrix_R(SEXP handle, SEXP dmat, SEXP json_config) {
namespace {
/*! \brief Exception type thrown when a proxy DMatrix operation reports failure. */
struct ProxyDmatrixError : std::exception {
};
/*!
 * \brief Owning wrapper around a proxy DMatrix handle.
 *
 * The constructor obtains the handle from XGProxyDMatrixCreate and throws
 * ProxyDmatrixError when that call returns a non-zero status. The destructor
 * releases the handle with XGDMatrixFree.
 */
struct ProxyDmatrixWrapper {
  DMatrixHandle proxy_dmat_handle{nullptr};

  ProxyDmatrixWrapper() {
    int res_code = XGProxyDMatrixCreate(&this->proxy_dmat_handle);
    if (res_code != 0) {
      throw ProxyDmatrixError();
    }
  }

  // The destructor frees the handle, so a copied wrapper would free the same
  // handle a second time. Copying is therefore disabled.
  ProxyDmatrixWrapper(const ProxyDmatrixWrapper &) = delete;
  ProxyDmatrixWrapper &operator=(const ProxyDmatrixWrapper &) = delete;

  ~ProxyDmatrixWrapper() {
    if (this->proxy_dmat_handle) {
      XGDMatrixFree(this->proxy_dmat_handle);
      this->proxy_dmat_handle = nullptr;
    }
  }

  /*! \return the wrapped handle; ownership stays with this object. */
  DMatrixHandle get_handle() const {
    return this->proxy_dmat_handle;
  }
};
/*!
 * \brief Create a proxy DMatrix that carries the given base margin.
 *
 * \param base_margin R vector or matrix with the base margin, or NULL
 * \return an empty pointer when `base_margin` is NULL, otherwise a wrapper
 *         whose handle has "base_margin" set from the R object
 *
 * On failure of the proxy creation or of XGDMatrixSetInfoFromInterface, the
 * last XGBoost error message is raised as an R error.
 */
std::unique_ptr<ProxyDmatrixWrapper> GetProxyDMatrixWithBaseMargin(SEXP base_margin) {
  if (Rf_isNull(base_margin)) {
    return nullptr;
  }

  SEXP base_margin_dim = Rf_getAttrib(base_margin, R_DimSymbol);
  try {
    // An object without a 'dim' attribute is treated as a plain vector.
    const std::string array_str = Rf_isNull(base_margin_dim)
                                      ? MakeArrayInterfaceFromRVector(base_margin)
                                      : MakeArrayInterfaceFromRMat(base_margin);
    auto proxy_dmat = std::make_unique<ProxyDmatrixWrapper>();
    const int res_code = XGDMatrixSetInfoFromInterface(proxy_dmat->get_handle(),
                                                       "base_margin",
                                                       array_str.c_str());
    if (res_code != 0) {
      throw ProxyDmatrixError();
    }
    return proxy_dmat;
  } catch (const ProxyDmatrixError &) {
    // Fall through: the R error is raised only after this handler has been
    // left, so the exception object is destroyed before Rf_error jumps away.
  }
  Rf_error("%s", XGBGetLastError());
}
/*! \brief Kind of input object handed to the shared prediction routine. */
enum class PredictionInputType {
  DMatrix,
  DenseMatrix,
  CSRMatrix,
  DataFrame
};
SEXP XGBoosterPredictGeneric(SEXP handle, SEXP input_data, SEXP json_config,
PredictionInputType input_type, SEXP missing,
SEXP base_margin) {
SEXP r_out_shape;
SEXP r_out_result;
SEXP r_out = PROTECT(allocVector(VECSXP, 2));
@@ -1259,9 +1312,79 @@ XGB_DLL SEXP XGBoosterPredictFromDMatrix_R(SEXP handle, SEXP dmat, SEXP json_con
bst_ulong out_dim;
bst_ulong const *out_shape;
float const *out_result;
CHECK_CALL(XGBoosterPredictFromDMatrix(R_ExternalPtrAddr(handle),
R_ExternalPtrAddr(dmat), c_json_config,
&out_shape, &out_dim, &out_result));
int res_code;
{
switch (input_type) {
case PredictionInputType::DMatrix: {
res_code = XGBoosterPredictFromDMatrix(R_ExternalPtrAddr(handle),
R_ExternalPtrAddr(input_data), c_json_config,
&out_shape, &out_dim, &out_result);
break;
}
case PredictionInputType::CSRMatrix: {
std::unique_ptr<ProxyDmatrixWrapper> proxy_dmat = GetProxyDMatrixWithBaseMargin(
base_margin);
DMatrixHandle proxy_dmat_handle = proxy_dmat.get()? proxy_dmat->get_handle() : nullptr;
SEXP indptr = VECTOR_ELT(input_data, 0);
SEXP indices = VECTOR_ELT(input_data, 1);
SEXP data = VECTOR_ELT(input_data, 2);
const int ncol_csr = Rf_asInteger(VECTOR_ELT(input_data, 3));
const SEXPTYPE type_data = TYPEOF(data);
CHECK_EQ(type_data, REALSXP);
std::string sindptr, sindices, sdata;
CreateFromSparse(indptr, indices, data, &sindptr, &sindices, &sdata);
xgboost::StringView json_str(c_json_config);
xgboost::Json new_json = xgboost::Json::Load(json_str);
AddMissingToJson(&new_json, missing, type_data);
const std::string new_c_json = xgboost::Json::Dump(new_json);
res_code = XGBoosterPredictFromCSR(
R_ExternalPtrAddr(handle), sindptr.c_str(), sindices.c_str(), sdata.c_str(),
ncol_csr, new_c_json.c_str(), proxy_dmat_handle, &out_shape, &out_dim, &out_result);
break;
}
case PredictionInputType::DenseMatrix: {
std::unique_ptr<ProxyDmatrixWrapper> proxy_dmat = GetProxyDMatrixWithBaseMargin(
base_margin);
DMatrixHandle proxy_dmat_handle = proxy_dmat.get()? proxy_dmat->get_handle() : nullptr;
const std::string array_str = MakeArrayInterfaceFromRMat(input_data);
xgboost::StringView json_str(c_json_config);
xgboost::Json new_json = xgboost::Json::Load(json_str);
AddMissingToJson(&new_json, missing, TYPEOF(input_data));
const std::string new_c_json = xgboost::Json::Dump(new_json);
res_code = XGBoosterPredictFromDense(
R_ExternalPtrAddr(handle), array_str.c_str(), new_c_json.c_str(),
proxy_dmat_handle, &out_shape, &out_dim, &out_result);
break;
}
case PredictionInputType::DataFrame: {
std::unique_ptr<ProxyDmatrixWrapper> proxy_dmat = GetProxyDMatrixWithBaseMargin(
base_margin);
DMatrixHandle proxy_dmat_handle = proxy_dmat.get()? proxy_dmat->get_handle() : nullptr;
const std::string df_str = MakeArrayInterfaceFromRDataFrame(input_data);
xgboost::StringView json_str(c_json_config);
xgboost::Json new_json = xgboost::Json::Load(json_str);
AddMissingToJson(&new_json, missing, REALSXP);
const std::string new_c_json = xgboost::Json::Dump(new_json);
res_code = XGBoosterPredictFromColumnar(
R_ExternalPtrAddr(handle), df_str.c_str(), new_c_json.c_str(),
proxy_dmat_handle, &out_shape, &out_dim, &out_result);
break;
}
}
}
CHECK_CALL(res_code);
r_out_shape = PROTECT(allocVector(INTSXP, out_dim));
size_t len = 1;
@@ -1282,6 +1405,31 @@ XGB_DLL SEXP XGBoosterPredictFromDMatrix_R(SEXP handle, SEXP dmat, SEXP json_con
return r_out;
}
} // namespace
// Prediction on a DMatrix: forwards to the shared routine. No 'missing' value
// and no base margin are passed (both are R_NilValue).
XGB_DLL SEXP XGBoosterPredictFromDMatrix_R(SEXP handle, SEXP dmat, SEXP json_config) {
  constexpr PredictionInputType kInputKind = PredictionInputType::DMatrix;
  return XGBoosterPredictGeneric(handle, dmat, json_config, kInputKind,
                                 /*missing=*/R_NilValue, /*base_margin=*/R_NilValue);
}
// Prediction on a dense R matrix: forwards to the shared routine together with
// the caller's 'missing' value and base margin.
XGB_DLL SEXP XGBoosterPredictFromDense_R(SEXP handle, SEXP R_mat, SEXP missing,
                                         SEXP json_config, SEXP base_margin) {
  constexpr PredictionInputType kInputKind = PredictionInputType::DenseMatrix;
  return XGBoosterPredictGeneric(handle, R_mat, json_config, kInputKind, missing, base_margin);
}
// Prediction on CSR input passed as an R list: forwards to the shared routine
// together with the caller's 'missing' value and base margin.
XGB_DLL SEXP XGBoosterPredictFromCSR_R(SEXP handle, SEXP lst, SEXP missing,
                                       SEXP json_config, SEXP base_margin) {
  constexpr PredictionInputType kInputKind = PredictionInputType::CSRMatrix;
  return XGBoosterPredictGeneric(handle, lst, json_config, kInputKind, missing, base_margin);
}
// Prediction on an R data.frame: forwards to the shared routine together with
// the caller's 'missing' value and base margin.
XGB_DLL SEXP XGBoosterPredictFromColumnar_R(SEXP handle, SEXP R_df, SEXP missing,
                                            SEXP json_config, SEXP base_margin) {
  constexpr PredictionInputType kInputKind = PredictionInputType::DataFrame;
  return XGBoosterPredictGeneric(handle, R_df, json_config, kInputKind, missing, base_margin);
}
XGB_DLL SEXP XGBoosterLoadModel_R(SEXP handle, SEXP fname) {
R_API_BEGIN();
CHECK_CALL(XGBoosterLoadModel(R_ExternalPtrAddr(handle), CHAR(asChar(fname))));

View File

@@ -371,6 +371,50 @@ XGB_DLL SEXP XGBoosterEvalOneIter_R(SEXP handle, SEXP iter, SEXP dmats, SEXP evn
* \return A list containing 2 vectors, first one for shape while second one for prediction result.
*/
XGB_DLL SEXP XGBoosterPredictFromDMatrix_R(SEXP handle, SEXP dmat, SEXP json_config);
/*!
* \brief Run prediction on R dense matrix
* \param handle handle
* \param R_mat R matrix
* \param missing missing value
* \param json_config See `XGBoosterPredictFromDense` in xgboost c_api.h. Doesn't include 'missing'
* \param base_margin base margin for the prediction
*
* \return A list containing 2 vectors, first one for shape while second one for prediction result.
*/
XGB_DLL SEXP XGBoosterPredictFromDense_R(SEXP handle, SEXP R_mat, SEXP missing,
SEXP json_config, SEXP base_margin);
/*!
* \brief Run prediction on R CSR matrix
* \param handle handle
* \param lst An R list, containing, in this order:
* (a) 'p' array (a.k.a. indptr)
* (b) 'j' array (a.k.a. indices)
* (c) 'x' array (a.k.a. data / values)
* (d) number of columns
* \param missing missing value
* \param json_config See `XGBoosterPredictFromCSR` in xgboost c_api.h. Doesn't include 'missing'
* \param base_margin base margin for the prediction
*
* \return A list containing 2 vectors, first one for shape while second one for prediction result.
*/
XGB_DLL SEXP XGBoosterPredictFromCSR_R(SEXP handle, SEXP lst, SEXP missing,
SEXP json_config, SEXP base_margin);
/*!
* \brief Run prediction on R data.frame
* \param handle handle
* \param R_df R data.frame
* \param missing missing value
* \param json_config See `XGBoosterPredictFromColumnar` in xgboost c_api.h. Doesn't include 'missing'
* \param base_margin base margin for the prediction
*
* \return A list containing 2 vectors, first one for shape while second one for prediction result.
*/
XGB_DLL SEXP XGBoosterPredictFromColumnar_R(SEXP handle, SEXP R_df, SEXP missing,
SEXP json_config, SEXP base_margin);
/*!
* \brief load model from existing file
* \param handle handle