[R] Use inplace predict (#9829)
--------- Co-authored-by: Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
@@ -37,6 +37,9 @@ extern SEXP XGBoosterLoadJsonConfig_R(SEXP handle, SEXP value);
|
||||
extern SEXP XGBoosterSerializeToBuffer_R(SEXP handle);
|
||||
extern SEXP XGBoosterUnserializeFromBuffer_R(SEXP handle, SEXP raw);
|
||||
extern SEXP XGBoosterPredictFromDMatrix_R(SEXP, SEXP, SEXP);
|
||||
extern SEXP XGBoosterPredictFromDense_R(SEXP, SEXP, SEXP, SEXP, SEXP);
|
||||
extern SEXP XGBoosterPredictFromCSR_R(SEXP, SEXP, SEXP, SEXP, SEXP);
|
||||
extern SEXP XGBoosterPredictFromColumnar_R(SEXP, SEXP, SEXP, SEXP, SEXP);
|
||||
extern SEXP XGBoosterSaveModel_R(SEXP, SEXP);
|
||||
extern SEXP XGBoosterSetAttr_R(SEXP, SEXP, SEXP);
|
||||
extern SEXP XGBoosterSetParam_R(SEXP, SEXP, SEXP);
|
||||
@@ -96,6 +99,9 @@ static const R_CallMethodDef CallEntries[] = {
|
||||
{"XGBoosterSerializeToBuffer_R", (DL_FUNC) &XGBoosterSerializeToBuffer_R, 1},
|
||||
{"XGBoosterUnserializeFromBuffer_R", (DL_FUNC) &XGBoosterUnserializeFromBuffer_R, 2},
|
||||
{"XGBoosterPredictFromDMatrix_R", (DL_FUNC) &XGBoosterPredictFromDMatrix_R, 3},
|
||||
{"XGBoosterPredictFromDense_R", (DL_FUNC) &XGBoosterPredictFromDense_R, 5},
|
||||
{"XGBoosterPredictFromCSR_R", (DL_FUNC) &XGBoosterPredictFromCSR_R, 5},
|
||||
{"XGBoosterPredictFromColumnar_R", (DL_FUNC) &XGBoosterPredictFromColumnar_R, 5},
|
||||
{"XGBoosterSaveModel_R", (DL_FUNC) &XGBoosterSaveModel_R, 2},
|
||||
{"XGBoosterSetAttr_R", (DL_FUNC) &XGBoosterSetAttr_R, 3},
|
||||
{"XGBoosterSetParam_R", (DL_FUNC) &XGBoosterSetParam_R, 3},
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <memory>
|
||||
#include <limits>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
@@ -207,25 +208,24 @@ SEXP SafeAllocInteger(size_t size, SEXP continuation_token) {
|
||||
return xgboost::Json::Dump(jinterface);
|
||||
}
|
||||
|
||||
[[nodiscard]] std::string MakeJsonConfigForArray(SEXP missing, SEXP n_threads, SEXPTYPE arr_type) {
|
||||
using namespace ::xgboost; // NOLINT
|
||||
Json jconfig{Object{}};
|
||||
|
||||
const SEXPTYPE missing_type = TYPEOF(missing);
|
||||
if (Rf_isNull(missing) || (missing_type == REALSXP && ISNAN(Rf_asReal(missing))) ||
|
||||
(missing_type == LGLSXP && Rf_asLogical(missing) == R_NaInt) ||
|
||||
(missing_type == INTSXP && Rf_asInteger(missing) == R_NaInt)) {
|
||||
void AddMissingToJson(xgboost::Json *jconfig, SEXP missing, SEXPTYPE arr_type) {
|
||||
if (Rf_isNull(missing) || ISNAN(Rf_asReal(missing))) {
|
||||
// missing is not specified
|
||||
if (arr_type == REALSXP) {
|
||||
jconfig["missing"] = std::numeric_limits<double>::quiet_NaN();
|
||||
(*jconfig)["missing"] = std::numeric_limits<double>::quiet_NaN();
|
||||
} else {
|
||||
jconfig["missing"] = R_NaInt;
|
||||
(*jconfig)["missing"] = R_NaInt;
|
||||
}
|
||||
} else {
|
||||
// missing specified
|
||||
jconfig["missing"] = Rf_asReal(missing);
|
||||
(*jconfig)["missing"] = Rf_asReal(missing);
|
||||
}
|
||||
}
|
||||
|
||||
[[nodiscard]] std::string MakeJsonConfigForArray(SEXP missing, SEXP n_threads, SEXPTYPE arr_type) {
|
||||
using namespace ::xgboost; // NOLINT
|
||||
Json jconfig{Object{}};
|
||||
AddMissingToJson(&jconfig, missing, arr_type);
|
||||
jconfig["nthread"] = Rf_asInteger(n_threads);
|
||||
return Json::Dump(jconfig);
|
||||
}
|
||||
@@ -411,7 +411,7 @@ XGB_DLL SEXP XGDMatrixCreateFromDF_R(SEXP df, SEXP missing, SEXP n_threads) {
|
||||
DMatrixHandle handle;
|
||||
std::int32_t rc{0};
|
||||
{
|
||||
std::string sinterface = MakeArrayInterfaceFromRDataFrame(df);
|
||||
const std::string sinterface = MakeArrayInterfaceFromRDataFrame(df);
|
||||
xgboost::Json jconfig{xgboost::Object{}};
|
||||
jconfig["missing"] = asReal(missing);
|
||||
jconfig["nthread"] = asInteger(n_threads);
|
||||
@@ -463,7 +463,7 @@ XGB_DLL SEXP XGDMatrixCreateFromCSC_R(SEXP indptr, SEXP indices, SEXP data, SEXP
|
||||
Json jconfig{Object{}};
|
||||
// Construct configuration
|
||||
jconfig["nthread"] = Integer{threads};
|
||||
jconfig["missing"] = xgboost::Number{asReal(missing)};
|
||||
AddMissingToJson(&jconfig, missing, TYPEOF(data));
|
||||
std::string config;
|
||||
Json::Dump(jconfig, &config);
|
||||
res_code = XGDMatrixCreateFromCSC(sindptr.c_str(), sindices.c_str(), sdata.c_str(), nrow,
|
||||
@@ -498,7 +498,7 @@ XGB_DLL SEXP XGDMatrixCreateFromCSR_R(SEXP indptr, SEXP indices, SEXP data, SEXP
|
||||
Json jconfig{Object{}};
|
||||
// Construct configuration
|
||||
jconfig["nthread"] = Integer{threads};
|
||||
jconfig["missing"] = xgboost::Number{asReal(missing)};
|
||||
AddMissingToJson(&jconfig, missing, TYPEOF(data));
|
||||
std::string config;
|
||||
Json::Dump(jconfig, &config);
|
||||
res_code = XGDMatrixCreateFromCSR(sindptr.c_str(), sindices.c_str(), sdata.c_str(), ncol,
|
||||
@@ -1247,7 +1247,60 @@ XGB_DLL SEXP XGBoosterEvalOneIter_R(SEXP handle, SEXP iter, SEXP dmats, SEXP evn
|
||||
return mkString(ret);
|
||||
}
|
||||
|
||||
XGB_DLL SEXP XGBoosterPredictFromDMatrix_R(SEXP handle, SEXP dmat, SEXP json_config) {
|
||||
namespace {
|
||||
|
||||
struct ProxyDmatrixError : public std::exception {};
|
||||
|
||||
struct ProxyDmatrixWrapper {
|
||||
DMatrixHandle proxy_dmat_handle;
|
||||
|
||||
ProxyDmatrixWrapper() {
|
||||
int res_code = XGProxyDMatrixCreate(&this->proxy_dmat_handle);
|
||||
if (res_code != 0) {
|
||||
throw ProxyDmatrixError();
|
||||
}
|
||||
}
|
||||
|
||||
~ProxyDmatrixWrapper() {
|
||||
if (this->proxy_dmat_handle) {
|
||||
XGDMatrixFree(this->proxy_dmat_handle);
|
||||
this->proxy_dmat_handle = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
DMatrixHandle get_handle() {
|
||||
return this->proxy_dmat_handle;
|
||||
}
|
||||
};
|
||||
|
||||
std::unique_ptr<ProxyDmatrixWrapper> GetProxyDMatrixWithBaseMargin(SEXP base_margin) {
|
||||
if (Rf_isNull(base_margin)) {
|
||||
return std::unique_ptr<ProxyDmatrixWrapper>(nullptr);
|
||||
}
|
||||
|
||||
SEXP base_margin_dim = Rf_getAttrib(base_margin, R_DimSymbol);
|
||||
int res_code;
|
||||
try {
|
||||
const std::string array_str = Rf_isNull(base_margin_dim)?
|
||||
MakeArrayInterfaceFromRVector(base_margin) : MakeArrayInterfaceFromRMat(base_margin);
|
||||
std::unique_ptr<ProxyDmatrixWrapper> proxy_dmat(new ProxyDmatrixWrapper());
|
||||
res_code = XGDMatrixSetInfoFromInterface(proxy_dmat->get_handle(),
|
||||
"base_margin",
|
||||
array_str.c_str());
|
||||
if (res_code != 0) {
|
||||
throw ProxyDmatrixError();
|
||||
}
|
||||
return proxy_dmat;
|
||||
} catch(ProxyDmatrixError &err) {
|
||||
Rf_error("%s", XGBGetLastError());
|
||||
}
|
||||
}
|
||||
|
||||
enum class PredictionInputType {DMatrix, DenseMatrix, CSRMatrix, DataFrame};
|
||||
|
||||
SEXP XGBoosterPredictGeneric(SEXP handle, SEXP input_data, SEXP json_config,
|
||||
PredictionInputType input_type, SEXP missing,
|
||||
SEXP base_margin) {
|
||||
SEXP r_out_shape;
|
||||
SEXP r_out_result;
|
||||
SEXP r_out = PROTECT(allocVector(VECSXP, 2));
|
||||
@@ -1259,9 +1312,79 @@ XGB_DLL SEXP XGBoosterPredictFromDMatrix_R(SEXP handle, SEXP dmat, SEXP json_con
|
||||
bst_ulong out_dim;
|
||||
bst_ulong const *out_shape;
|
||||
float const *out_result;
|
||||
CHECK_CALL(XGBoosterPredictFromDMatrix(R_ExternalPtrAddr(handle),
|
||||
R_ExternalPtrAddr(dmat), c_json_config,
|
||||
&out_shape, &out_dim, &out_result));
|
||||
|
||||
int res_code;
|
||||
{
|
||||
switch (input_type) {
|
||||
case PredictionInputType::DMatrix: {
|
||||
res_code = XGBoosterPredictFromDMatrix(R_ExternalPtrAddr(handle),
|
||||
R_ExternalPtrAddr(input_data), c_json_config,
|
||||
&out_shape, &out_dim, &out_result);
|
||||
break;
|
||||
}
|
||||
|
||||
case PredictionInputType::CSRMatrix: {
|
||||
std::unique_ptr<ProxyDmatrixWrapper> proxy_dmat = GetProxyDMatrixWithBaseMargin(
|
||||
base_margin);
|
||||
DMatrixHandle proxy_dmat_handle = proxy_dmat.get()? proxy_dmat->get_handle() : nullptr;
|
||||
|
||||
SEXP indptr = VECTOR_ELT(input_data, 0);
|
||||
SEXP indices = VECTOR_ELT(input_data, 1);
|
||||
SEXP data = VECTOR_ELT(input_data, 2);
|
||||
const int ncol_csr = Rf_asInteger(VECTOR_ELT(input_data, 3));
|
||||
const SEXPTYPE type_data = TYPEOF(data);
|
||||
CHECK_EQ(type_data, REALSXP);
|
||||
std::string sindptr, sindices, sdata;
|
||||
CreateFromSparse(indptr, indices, data, &sindptr, &sindices, &sdata);
|
||||
|
||||
xgboost::StringView json_str(c_json_config);
|
||||
xgboost::Json new_json = xgboost::Json::Load(json_str);
|
||||
AddMissingToJson(&new_json, missing, type_data);
|
||||
const std::string new_c_json = xgboost::Json::Dump(new_json);
|
||||
|
||||
res_code = XGBoosterPredictFromCSR(
|
||||
R_ExternalPtrAddr(handle), sindptr.c_str(), sindices.c_str(), sdata.c_str(),
|
||||
ncol_csr, new_c_json.c_str(), proxy_dmat_handle, &out_shape, &out_dim, &out_result);
|
||||
break;
|
||||
}
|
||||
|
||||
case PredictionInputType::DenseMatrix: {
|
||||
std::unique_ptr<ProxyDmatrixWrapper> proxy_dmat = GetProxyDMatrixWithBaseMargin(
|
||||
base_margin);
|
||||
DMatrixHandle proxy_dmat_handle = proxy_dmat.get()? proxy_dmat->get_handle() : nullptr;
|
||||
const std::string array_str = MakeArrayInterfaceFromRMat(input_data);
|
||||
|
||||
xgboost::StringView json_str(c_json_config);
|
||||
xgboost::Json new_json = xgboost::Json::Load(json_str);
|
||||
AddMissingToJson(&new_json, missing, TYPEOF(input_data));
|
||||
const std::string new_c_json = xgboost::Json::Dump(new_json);
|
||||
|
||||
res_code = XGBoosterPredictFromDense(
|
||||
R_ExternalPtrAddr(handle), array_str.c_str(), new_c_json.c_str(),
|
||||
proxy_dmat_handle, &out_shape, &out_dim, &out_result);
|
||||
break;
|
||||
}
|
||||
|
||||
case PredictionInputType::DataFrame: {
|
||||
std::unique_ptr<ProxyDmatrixWrapper> proxy_dmat = GetProxyDMatrixWithBaseMargin(
|
||||
base_margin);
|
||||
DMatrixHandle proxy_dmat_handle = proxy_dmat.get()? proxy_dmat->get_handle() : nullptr;
|
||||
|
||||
const std::string df_str = MakeArrayInterfaceFromRDataFrame(input_data);
|
||||
|
||||
xgboost::StringView json_str(c_json_config);
|
||||
xgboost::Json new_json = xgboost::Json::Load(json_str);
|
||||
AddMissingToJson(&new_json, missing, REALSXP);
|
||||
const std::string new_c_json = xgboost::Json::Dump(new_json);
|
||||
|
||||
res_code = XGBoosterPredictFromColumnar(
|
||||
R_ExternalPtrAddr(handle), df_str.c_str(), new_c_json.c_str(),
|
||||
proxy_dmat_handle, &out_shape, &out_dim, &out_result);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
CHECK_CALL(res_code);
|
||||
|
||||
r_out_shape = PROTECT(allocVector(INTSXP, out_dim));
|
||||
size_t len = 1;
|
||||
@@ -1282,6 +1405,31 @@ XGB_DLL SEXP XGBoosterPredictFromDMatrix_R(SEXP handle, SEXP dmat, SEXP json_con
|
||||
return r_out;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
XGB_DLL SEXP XGBoosterPredictFromDMatrix_R(SEXP handle, SEXP dmat, SEXP json_config) {
|
||||
return XGBoosterPredictGeneric(handle, dmat, json_config,
|
||||
PredictionInputType::DMatrix, R_NilValue, R_NilValue);
|
||||
}
|
||||
|
||||
XGB_DLL SEXP XGBoosterPredictFromDense_R(SEXP handle, SEXP R_mat, SEXP missing,
|
||||
SEXP json_config, SEXP base_margin) {
|
||||
return XGBoosterPredictGeneric(handle, R_mat, json_config,
|
||||
PredictionInputType::DenseMatrix, missing, base_margin);
|
||||
}
|
||||
|
||||
XGB_DLL SEXP XGBoosterPredictFromCSR_R(SEXP handle, SEXP lst, SEXP missing,
|
||||
SEXP json_config, SEXP base_margin) {
|
||||
return XGBoosterPredictGeneric(handle, lst, json_config,
|
||||
PredictionInputType::CSRMatrix, missing, base_margin);
|
||||
}
|
||||
|
||||
XGB_DLL SEXP XGBoosterPredictFromColumnar_R(SEXP handle, SEXP R_df, SEXP missing,
|
||||
SEXP json_config, SEXP base_margin) {
|
||||
return XGBoosterPredictGeneric(handle, R_df, json_config,
|
||||
PredictionInputType::DataFrame, missing, base_margin);
|
||||
}
|
||||
|
||||
XGB_DLL SEXP XGBoosterLoadModel_R(SEXP handle, SEXP fname) {
|
||||
R_API_BEGIN();
|
||||
CHECK_CALL(XGBoosterLoadModel(R_ExternalPtrAddr(handle), CHAR(asChar(fname))));
|
||||
|
||||
@@ -371,6 +371,50 @@ XGB_DLL SEXP XGBoosterEvalOneIter_R(SEXP handle, SEXP iter, SEXP dmats, SEXP evn
|
||||
* \return A list containing 2 vectors, first one for shape while second one for prediction result.
|
||||
*/
|
||||
XGB_DLL SEXP XGBoosterPredictFromDMatrix_R(SEXP handle, SEXP dmat, SEXP json_config);
|
||||
|
||||
/*!
|
||||
* \brief Run prediction on R dense matrix
|
||||
* \param handle handle
|
||||
* \param R_mat R matrix
|
||||
* \param missing missing value
|
||||
* \param json_config See `XGBoosterPredictFromDense` in xgboost c_api.h. Doesn't include 'missing'
|
||||
* \param base_margin base margin for the prediction
|
||||
*
|
||||
* \return A list containing 2 vectors, first one for shape while second one for prediction result.
|
||||
*/
|
||||
XGB_DLL SEXP XGBoosterPredictFromDense_R(SEXP handle, SEXP R_mat, SEXP missing,
|
||||
SEXP json_config, SEXP base_margin);
|
||||
|
||||
/*!
|
||||
* \brief Run prediction on R CSR matrix
|
||||
* \param handle handle
|
||||
* \param lst An R list, containing, in this order:
|
||||
* (a) 'p' array (a.k.a. indptr)
|
||||
* (b) 'j' array (a.k.a. indices)
|
||||
* (c) 'x' array (a.k.a. data / values)
|
||||
* (d) number of columns
|
||||
* \param missing missing value
|
||||
* \param json_config See `XGBoosterPredictFromCSR` in xgboost c_api.h. Doesn't include 'missing'
|
||||
* \param base_margin base margin for the prediction
|
||||
*
|
||||
* \return A list containing 2 vectors, first one for shape while second one for prediction result.
|
||||
*/
|
||||
XGB_DLL SEXP XGBoosterPredictFromCSR_R(SEXP handle, SEXP lst, SEXP missing,
|
||||
SEXP json_config, SEXP base_margin);
|
||||
|
||||
/*!
|
||||
* \brief Run prediction on R data.frame
|
||||
* \param handle handle
|
||||
* \param R_df R data.frame
|
||||
* \param missing missing value
|
||||
* \param json_config See `XGBoosterPredictFromDense` in xgboost c_api.h. Doesn't include 'missing'
|
||||
* \param base_margin base margin for the prediction
|
||||
*
|
||||
* \return A list containing 2 vectors, first one for shape while second one for prediction result.
|
||||
*/
|
||||
XGB_DLL SEXP XGBoosterPredictFromColumnar_R(SEXP handle, SEXP R_df, SEXP missing,
|
||||
SEXP json_config, SEXP base_margin);
|
||||
|
||||
/*!
|
||||
* \brief load model from existing file
|
||||
* \param handle handle
|
||||
|
||||
Reference in New Issue
Block a user