Add checks to C pointer arguments. (#8254)
This commit is contained in:
parent
eb7bbee2c9
commit
3fd331f8f2
@ -62,7 +62,7 @@ void XGBBuildInfoDevice(Json *p_info) {
|
|||||||
|
|
||||||
XGB_DLL int XGBuildInfo(char const **out) {
|
XGB_DLL int XGBuildInfo(char const **out) {
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
CHECK(out) << "Invalid input pointer";
|
xgboost_CHECK_C_ARG_PTR(out);
|
||||||
Json info{Object{}};
|
Json info{Object{}};
|
||||||
|
|
||||||
#if defined(XGBOOST_BUILTIN_PREFETCH_PRESENT)
|
#if defined(XGBOOST_BUILTIN_PREFETCH_PRESENT)
|
||||||
@ -124,7 +124,10 @@ XGB_DLL int XGBRegisterLogCallback(void (*callback)(const char*)) {
|
|||||||
|
|
||||||
XGB_DLL int XGBSetGlobalConfig(const char* json_str) {
|
XGB_DLL int XGBSetGlobalConfig(const char* json_str) {
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
|
|
||||||
|
xgboost_CHECK_C_ARG_PTR(json_str);
|
||||||
Json config{Json::Load(StringView{json_str})};
|
Json config{Json::Load(StringView{json_str})};
|
||||||
|
|
||||||
for (auto& items : get<Object>(config)) {
|
for (auto& items : get<Object>(config)) {
|
||||||
switch (items.second.GetValue().Type()) {
|
switch (items.second.GetValue().Type()) {
|
||||||
case xgboost::Value::ValueKind::kInteger: {
|
case xgboost::Value::ValueKind::kInteger: {
|
||||||
@ -200,6 +203,8 @@ XGB_DLL int XGBGetGlobalConfig(const char** json_str) {
|
|||||||
|
|
||||||
auto& local = *GlobalConfigAPIThreadLocalStore::Get();
|
auto& local = *GlobalConfigAPIThreadLocalStore::Get();
|
||||||
Json::Dump(config, &local.ret_str);
|
Json::Dump(config, &local.ret_str);
|
||||||
|
|
||||||
|
xgboost_CHECK_C_ARG_PTR(json_str);
|
||||||
*json_str = local.ret_str.c_str();
|
*json_str = local.ret_str.c_str();
|
||||||
API_END();
|
API_END();
|
||||||
}
|
}
|
||||||
@ -216,6 +221,9 @@ XGB_DLL int XGDMatrixCreateFromFile(const char *fname, int silent, DMatrixHandle
|
|||||||
load_row_split = true;
|
load_row_split = true;
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
xgboost_CHECK_C_ARG_PTR(fname);
|
||||||
|
xgboost_CHECK_C_ARG_PTR(out);
|
||||||
*out = new std::shared_ptr<DMatrix>(DMatrix::Load(fname, silent != 0, load_row_split));
|
*out = new std::shared_ptr<DMatrix>(DMatrix::Load(fname, silent != 0, load_row_split));
|
||||||
API_END();
|
API_END();
|
||||||
}
|
}
|
||||||
@ -232,6 +240,7 @@ XGB_DLL int XGDMatrixCreateFromDataIter(
|
|||||||
}
|
}
|
||||||
xgboost::data::IteratorAdapter<DataIterHandle, XGBCallbackDataIterNext,
|
xgboost::data::IteratorAdapter<DataIterHandle, XGBCallbackDataIterNext,
|
||||||
XGBoostBatchCSR> adapter(data_handle, callback);
|
XGBoostBatchCSR> adapter(data_handle, callback);
|
||||||
|
xgboost_CHECK_C_ARG_PTR(out);
|
||||||
*out = new std::shared_ptr<DMatrix> {
|
*out = new std::shared_ptr<DMatrix> {
|
||||||
DMatrix::Create(
|
DMatrix::Create(
|
||||||
&adapter, std::numeric_limits<float>::quiet_NaN(),
|
&adapter, std::numeric_limits<float>::quiet_NaN(),
|
||||||
@ -265,10 +274,17 @@ XGB_DLL int XGDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy
|
|||||||
DataIterResetCallback *reset, XGDMatrixCallbackNext *next,
|
DataIterResetCallback *reset, XGDMatrixCallbackNext *next,
|
||||||
char const *c_json_config, DMatrixHandle *out) {
|
char const *c_json_config, DMatrixHandle *out) {
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
|
xgboost_CHECK_C_ARG_PTR(c_json_config);
|
||||||
|
|
||||||
auto config = Json::Load(StringView{c_json_config});
|
auto config = Json::Load(StringView{c_json_config});
|
||||||
auto missing = GetMissing(config);
|
auto missing = GetMissing(config);
|
||||||
std::string cache = RequiredArg<String>(config, "cache_prefix", __func__);
|
std::string cache = RequiredArg<String>(config, "cache_prefix", __func__);
|
||||||
auto n_threads = OptionalArg<Integer, int64_t>(config, "nthread", common::OmpGetNumThreads(0));
|
auto n_threads = OptionalArg<Integer, int64_t>(config, "nthread", common::OmpGetNumThreads(0));
|
||||||
|
|
||||||
|
xgboost_CHECK_C_ARG_PTR(next);
|
||||||
|
xgboost_CHECK_C_ARG_PTR(reset);
|
||||||
|
xgboost_CHECK_C_ARG_PTR(out);
|
||||||
|
|
||||||
*out = new std::shared_ptr<xgboost::DMatrix>{
|
*out = new std::shared_ptr<xgboost::DMatrix>{
|
||||||
xgboost::DMatrix::Create(iter, proxy, reset, next, missing, n_threads, cache)};
|
xgboost::DMatrix::Create(iter, proxy, reset, next, missing, n_threads, cache)};
|
||||||
API_END();
|
API_END();
|
||||||
@ -300,11 +316,16 @@ XGB_DLL int XGQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHand
|
|||||||
CHECK(_ref) << err;
|
CHECK(_ref) << err;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
xgboost_CHECK_C_ARG_PTR(config);
|
||||||
auto jconfig = Json::Load(StringView{config});
|
auto jconfig = Json::Load(StringView{config});
|
||||||
auto missing = GetMissing(jconfig);
|
auto missing = GetMissing(jconfig);
|
||||||
auto n_threads = OptionalArg<Integer, int64_t>(jconfig, "nthread", common::OmpGetNumThreads(0));
|
auto n_threads = OptionalArg<Integer, int64_t>(jconfig, "nthread", common::OmpGetNumThreads(0));
|
||||||
auto max_bin = OptionalArg<Integer, int64_t>(jconfig, "max_bin", 256);
|
auto max_bin = OptionalArg<Integer, int64_t>(jconfig, "max_bin", 256);
|
||||||
|
|
||||||
|
xgboost_CHECK_C_ARG_PTR(next);
|
||||||
|
xgboost_CHECK_C_ARG_PTR(reset);
|
||||||
|
xgboost_CHECK_C_ARG_PTR(out);
|
||||||
|
|
||||||
*out = new std::shared_ptr<xgboost::DMatrix>{
|
*out = new std::shared_ptr<xgboost::DMatrix>{
|
||||||
xgboost::DMatrix::Create(iter, proxy, _ref, reset, next, missing, n_threads, max_bin)};
|
xgboost::DMatrix::Create(iter, proxy, _ref, reset, next, missing, n_threads, max_bin)};
|
||||||
API_END();
|
API_END();
|
||||||
@ -312,6 +333,7 @@ XGB_DLL int XGQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHand
|
|||||||
|
|
||||||
XGB_DLL int XGProxyDMatrixCreate(DMatrixHandle* out) {
|
XGB_DLL int XGProxyDMatrixCreate(DMatrixHandle* out) {
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
|
xgboost_CHECK_C_ARG_PTR(out);
|
||||||
*out = new std::shared_ptr<xgboost::DMatrix>(new xgboost::data::DMatrixProxy);;
|
*out = new std::shared_ptr<xgboost::DMatrix>(new xgboost::data::DMatrixProxy);;
|
||||||
API_END();
|
API_END();
|
||||||
}
|
}
|
||||||
@ -321,6 +343,7 @@ XGProxyDMatrixSetDataCudaArrayInterface(DMatrixHandle handle,
|
|||||||
char const *c_interface_str) {
|
char const *c_interface_str) {
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
CHECK_HANDLE();
|
CHECK_HANDLE();
|
||||||
|
xgboost_CHECK_C_ARG_PTR(c_interface_str);
|
||||||
auto p_m = static_cast<std::shared_ptr<xgboost::DMatrix> *>(handle);
|
auto p_m = static_cast<std::shared_ptr<xgboost::DMatrix> *>(handle);
|
||||||
CHECK(p_m);
|
CHECK(p_m);
|
||||||
auto m = static_cast<xgboost::data::DMatrixProxy*>(p_m->get());
|
auto m = static_cast<xgboost::data::DMatrixProxy*>(p_m->get());
|
||||||
@ -333,6 +356,7 @@ XGB_DLL int XGProxyDMatrixSetDataCudaColumnar(DMatrixHandle handle,
|
|||||||
char const *c_interface_str) {
|
char const *c_interface_str) {
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
CHECK_HANDLE();
|
CHECK_HANDLE();
|
||||||
|
xgboost_CHECK_C_ARG_PTR(c_interface_str);
|
||||||
auto p_m = static_cast<std::shared_ptr<xgboost::DMatrix> *>(handle);
|
auto p_m = static_cast<std::shared_ptr<xgboost::DMatrix> *>(handle);
|
||||||
CHECK(p_m);
|
CHECK(p_m);
|
||||||
auto m = static_cast<xgboost::data::DMatrixProxy*>(p_m->get());
|
auto m = static_cast<xgboost::data::DMatrixProxy*>(p_m->get());
|
||||||
@ -345,6 +369,7 @@ XGB_DLL int XGProxyDMatrixSetDataDense(DMatrixHandle handle,
|
|||||||
char const *c_interface_str) {
|
char const *c_interface_str) {
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
CHECK_HANDLE();
|
CHECK_HANDLE();
|
||||||
|
xgboost_CHECK_C_ARG_PTR(c_interface_str);
|
||||||
auto p_m = static_cast<std::shared_ptr<xgboost::DMatrix> *>(handle);
|
auto p_m = static_cast<std::shared_ptr<xgboost::DMatrix> *>(handle);
|
||||||
CHECK(p_m);
|
CHECK(p_m);
|
||||||
auto m = static_cast<xgboost::data::DMatrixProxy*>(p_m->get());
|
auto m = static_cast<xgboost::data::DMatrixProxy*>(p_m->get());
|
||||||
@ -358,6 +383,9 @@ XGB_DLL int XGProxyDMatrixSetDataCSR(DMatrixHandle handle, char const *indptr,
|
|||||||
xgboost::bst_ulong ncol) {
|
xgboost::bst_ulong ncol) {
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
CHECK_HANDLE();
|
CHECK_HANDLE();
|
||||||
|
xgboost_CHECK_C_ARG_PTR(indptr);
|
||||||
|
xgboost_CHECK_C_ARG_PTR(indices);
|
||||||
|
xgboost_CHECK_C_ARG_PTR(data);
|
||||||
auto p_m = static_cast<std::shared_ptr<xgboost::DMatrix> *>(handle);
|
auto p_m = static_cast<std::shared_ptr<xgboost::DMatrix> *>(handle);
|
||||||
CHECK(p_m);
|
CHECK(p_m);
|
||||||
auto m = static_cast<xgboost::data::DMatrixProxy*>(p_m->get());
|
auto m = static_cast<xgboost::data::DMatrixProxy*>(p_m->get());
|
||||||
@ -387,11 +415,16 @@ XGB_DLL int XGDMatrixCreateFromCSR(char const *indptr,
|
|||||||
char const* c_json_config,
|
char const* c_json_config,
|
||||||
DMatrixHandle* out) {
|
DMatrixHandle* out) {
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
|
xgboost_CHECK_C_ARG_PTR(indptr);
|
||||||
|
xgboost_CHECK_C_ARG_PTR(indices);
|
||||||
|
xgboost_CHECK_C_ARG_PTR(data);
|
||||||
data::CSRArrayAdapter adapter(StringView{indptr}, StringView{indices},
|
data::CSRArrayAdapter adapter(StringView{indptr}, StringView{indices},
|
||||||
StringView{data}, ncol);
|
StringView{data}, ncol);
|
||||||
|
xgboost_CHECK_C_ARG_PTR(c_json_config);
|
||||||
auto config = Json::Load(StringView{c_json_config});
|
auto config = Json::Load(StringView{c_json_config});
|
||||||
float missing = GetMissing(config);
|
float missing = GetMissing(config);
|
||||||
auto n_threads = OptionalArg<Integer, int64_t>(config, "nthread", common::OmpGetNumThreads(0));
|
auto n_threads = OptionalArg<Integer, int64_t>(config, "nthread", common::OmpGetNumThreads(0));
|
||||||
|
xgboost_CHECK_C_ARG_PTR(out);
|
||||||
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(&adapter, missing, n_threads));
|
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(&adapter, missing, n_threads));
|
||||||
API_END();
|
API_END();
|
||||||
}
|
}
|
||||||
@ -400,11 +433,13 @@ XGB_DLL int XGDMatrixCreateFromDense(char const *data,
|
|||||||
char const *c_json_config,
|
char const *c_json_config,
|
||||||
DMatrixHandle *out) {
|
DMatrixHandle *out) {
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
xgboost::data::ArrayAdapter adapter{
|
xgboost_CHECK_C_ARG_PTR(data);
|
||||||
xgboost::data::ArrayAdapter(StringView{data})};
|
xgboost::data::ArrayAdapter adapter{xgboost::data::ArrayAdapter(StringView{data})};
|
||||||
|
xgboost_CHECK_C_ARG_PTR(c_json_config);
|
||||||
auto config = Json::Load(StringView{c_json_config});
|
auto config = Json::Load(StringView{c_json_config});
|
||||||
float missing = GetMissing(config);
|
float missing = GetMissing(config);
|
||||||
auto n_threads = OptionalArg<Integer, int64_t>(config, "nthread", common::OmpGetNumThreads(0));
|
auto n_threads = OptionalArg<Integer, int64_t>(config, "nthread", common::OmpGetNumThreads(0));
|
||||||
|
xgboost_CHECK_C_ARG_PTR(out);
|
||||||
*out =
|
*out =
|
||||||
new std::shared_ptr<DMatrix>(DMatrix::Create(&adapter, missing, n_threads));
|
new std::shared_ptr<DMatrix>(DMatrix::Create(&adapter, missing, n_threads));
|
||||||
API_END();
|
API_END();
|
||||||
@ -419,6 +454,7 @@ XGB_DLL int XGDMatrixCreateFromCSCEx(const size_t* col_ptr,
|
|||||||
DMatrixHandle* out) {
|
DMatrixHandle* out) {
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
data::CSCAdapter adapter(col_ptr, indices, data, nindptr - 1, num_row);
|
data::CSCAdapter adapter(col_ptr, indices, data, nindptr - 1, num_row);
|
||||||
|
xgboost_CHECK_C_ARG_PTR(out);
|
||||||
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(&adapter, std::nan(""), 1));
|
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(&adapter, std::nan(""), 1));
|
||||||
API_END();
|
API_END();
|
||||||
}
|
}
|
||||||
@ -429,6 +465,7 @@ XGB_DLL int XGDMatrixCreateFromMat(const bst_float* data,
|
|||||||
DMatrixHandle* out) {
|
DMatrixHandle* out) {
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
data::DenseAdapter adapter(data, nrow, ncol);
|
data::DenseAdapter adapter(data, nrow, ncol);
|
||||||
|
xgboost_CHECK_C_ARG_PTR(out);
|
||||||
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(&adapter, missing, 1));
|
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(&adapter, missing, 1));
|
||||||
API_END();
|
API_END();
|
||||||
}
|
}
|
||||||
@ -440,6 +477,7 @@ XGB_DLL int XGDMatrixCreateFromMat_omp(const bst_float* data, // NOLINT
|
|||||||
int nthread) {
|
int nthread) {
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
data::DenseAdapter adapter(data, nrow, ncol);
|
data::DenseAdapter adapter(data, nrow, ncol);
|
||||||
|
xgboost_CHECK_C_ARG_PTR(out);
|
||||||
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(&adapter, missing, nthread));
|
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(&adapter, missing, nthread));
|
||||||
API_END();
|
API_END();
|
||||||
}
|
}
|
||||||
@ -450,8 +488,8 @@ XGB_DLL int XGDMatrixCreateFromDT(void** data, const char** feature_stypes,
|
|||||||
int nthread) {
|
int nthread) {
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
data::DataTableAdapter adapter(data, feature_stypes, nrow, ncol);
|
data::DataTableAdapter adapter(data, feature_stypes, nrow, ncol);
|
||||||
*out = new std::shared_ptr<DMatrix>(
|
xgboost_CHECK_C_ARG_PTR(out);
|
||||||
DMatrix::Create(&adapter, std::nan(""), nthread));
|
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(&adapter, std::nan(""), nthread));
|
||||||
API_END();
|
API_END();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -467,11 +505,13 @@ XGB_DLL int XGImportArrowRecordBatch(DataIterHandle data_handle, void *ptr_array
|
|||||||
XGB_DLL int XGDMatrixCreateFromArrowCallback(XGDMatrixCallbackNext *next, char const *json_config,
|
XGB_DLL int XGDMatrixCreateFromArrowCallback(XGDMatrixCallbackNext *next, char const *json_config,
|
||||||
DMatrixHandle *out) {
|
DMatrixHandle *out) {
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
|
xgboost_CHECK_C_ARG_PTR(json_config);
|
||||||
auto config = Json::Load(StringView{json_config});
|
auto config = Json::Load(StringView{json_config});
|
||||||
auto missing = GetMissing(config);
|
auto missing = GetMissing(config);
|
||||||
int32_t n_threads = get<Integer const>(config["nthread"]);
|
int32_t n_threads = get<Integer const>(config["nthread"]);
|
||||||
n_threads = common::OmpGetNumThreads(n_threads);
|
n_threads = common::OmpGetNumThreads(n_threads);
|
||||||
data::RecordBatchesIterAdapter adapter(next, n_threads);
|
data::RecordBatchesIterAdapter adapter(next, n_threads);
|
||||||
|
xgboost_CHECK_C_ARG_PTR(out);
|
||||||
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(&adapter, missing, n_threads));
|
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(&adapter, missing, n_threads));
|
||||||
API_END();
|
API_END();
|
||||||
}
|
}
|
||||||
@ -480,6 +520,7 @@ XGB_DLL int XGDMatrixSliceDMatrix(DMatrixHandle handle,
|
|||||||
const int* idxset,
|
const int* idxset,
|
||||||
xgboost::bst_ulong len,
|
xgboost::bst_ulong len,
|
||||||
DMatrixHandle* out) {
|
DMatrixHandle* out) {
|
||||||
|
xgboost_CHECK_C_ARG_PTR(out);
|
||||||
return XGDMatrixSliceDMatrixEx(handle, idxset, len, out, 0);
|
return XGDMatrixSliceDMatrixEx(handle, idxset, len, out, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -516,6 +557,7 @@ XGB_DLL int XGDMatrixSaveBinary(DMatrixHandle handle, const char* fname,
|
|||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
CHECK_HANDLE();
|
CHECK_HANDLE();
|
||||||
auto dmat = static_cast<std::shared_ptr<DMatrix>*>(handle)->get();
|
auto dmat = static_cast<std::shared_ptr<DMatrix>*>(handle)->get();
|
||||||
|
xgboost_CHECK_C_ARG_PTR(fname);
|
||||||
if (data::SimpleDMatrix* derived = dynamic_cast<data::SimpleDMatrix*>(dmat)) {
|
if (data::SimpleDMatrix* derived = dynamic_cast<data::SimpleDMatrix*>(dmat)) {
|
||||||
derived->SaveToLocalFile(fname);
|
derived->SaveToLocalFile(fname);
|
||||||
} else {
|
} else {
|
||||||
@ -528,6 +570,7 @@ XGB_DLL int XGDMatrixSetFloatInfo(DMatrixHandle handle, const char *field, const
|
|||||||
xgboost::bst_ulong len) {
|
xgboost::bst_ulong len) {
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
CHECK_HANDLE();
|
CHECK_HANDLE();
|
||||||
|
xgboost_CHECK_C_ARG_PTR(field);
|
||||||
auto const& p_fmat = *static_cast<std::shared_ptr<DMatrix> *>(handle);
|
auto const& p_fmat = *static_cast<std::shared_ptr<DMatrix> *>(handle);
|
||||||
p_fmat->SetInfo(field, info, xgboost::DataType::kFloat32, len);
|
p_fmat->SetInfo(field, info, xgboost::DataType::kFloat32, len);
|
||||||
API_END();
|
API_END();
|
||||||
@ -537,6 +580,7 @@ XGB_DLL int XGDMatrixSetInfoFromInterface(DMatrixHandle handle, char const *fiel
|
|||||||
char const *interface_c_str) {
|
char const *interface_c_str) {
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
CHECK_HANDLE();
|
CHECK_HANDLE();
|
||||||
|
xgboost_CHECK_C_ARG_PTR(field);
|
||||||
auto const &p_fmat = *static_cast<std::shared_ptr<DMatrix> *>(handle);
|
auto const &p_fmat = *static_cast<std::shared_ptr<DMatrix> *>(handle);
|
||||||
p_fmat->SetInfo(field, interface_c_str);
|
p_fmat->SetInfo(field, interface_c_str);
|
||||||
API_END();
|
API_END();
|
||||||
@ -546,6 +590,7 @@ XGB_DLL int XGDMatrixSetUIntInfo(DMatrixHandle handle, const char *field, const
|
|||||||
xgboost::bst_ulong len) {
|
xgboost::bst_ulong len) {
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
CHECK_HANDLE();
|
CHECK_HANDLE();
|
||||||
|
xgboost_CHECK_C_ARG_PTR(field);
|
||||||
auto const &p_fmat = *static_cast<std::shared_ptr<DMatrix> *>(handle);
|
auto const &p_fmat = *static_cast<std::shared_ptr<DMatrix> *>(handle);
|
||||||
p_fmat->SetInfo(field, info, xgboost::DataType::kUInt32, len);
|
p_fmat->SetInfo(field, info, xgboost::DataType::kUInt32, len);
|
||||||
API_END();
|
API_END();
|
||||||
@ -557,6 +602,7 @@ XGB_DLL int XGDMatrixSetStrFeatureInfo(DMatrixHandle handle, const char *field,
|
|||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
CHECK_HANDLE();
|
CHECK_HANDLE();
|
||||||
auto &info = static_cast<std::shared_ptr<DMatrix> *>(handle)->get()->Info();
|
auto &info = static_cast<std::shared_ptr<DMatrix> *>(handle)->get()->Info();
|
||||||
|
xgboost_CHECK_C_ARG_PTR(field);
|
||||||
info.SetFeatureInfo(field, c_info, size);
|
info.SetFeatureInfo(field, c_info, size);
|
||||||
API_END();
|
API_END();
|
||||||
}
|
}
|
||||||
@ -572,12 +618,15 @@ XGB_DLL int XGDMatrixGetStrFeatureInfo(DMatrixHandle handle, const char *field,
|
|||||||
std::vector<const char *> &charp_vecs = m->GetThreadLocal().ret_vec_charp;
|
std::vector<const char *> &charp_vecs = m->GetThreadLocal().ret_vec_charp;
|
||||||
std::vector<std::string> &str_vecs = m->GetThreadLocal().ret_vec_str;
|
std::vector<std::string> &str_vecs = m->GetThreadLocal().ret_vec_str;
|
||||||
|
|
||||||
|
xgboost_CHECK_C_ARG_PTR(field);
|
||||||
info.GetFeatureInfo(field, &str_vecs);
|
info.GetFeatureInfo(field, &str_vecs);
|
||||||
|
|
||||||
charp_vecs.resize(str_vecs.size());
|
charp_vecs.resize(str_vecs.size());
|
||||||
for (size_t i = 0; i < str_vecs.size(); ++i) {
|
for (size_t i = 0; i < str_vecs.size(); ++i) {
|
||||||
charp_vecs[i] = str_vecs[i].c_str();
|
charp_vecs[i] = str_vecs[i].c_str();
|
||||||
}
|
}
|
||||||
|
xgboost_CHECK_C_ARG_PTR(out_features);
|
||||||
|
xgboost_CHECK_C_ARG_PTR(len);
|
||||||
*out_features = dmlc::BeginPtr(charp_vecs);
|
*out_features = dmlc::BeginPtr(charp_vecs);
|
||||||
*len = static_cast<xgboost::bst_ulong>(charp_vecs.size());
|
*len = static_cast<xgboost::bst_ulong>(charp_vecs.size());
|
||||||
API_END();
|
API_END();
|
||||||
@ -589,6 +638,7 @@ XGB_DLL int XGDMatrixSetDenseInfo(DMatrixHandle handle, const char *field, void
|
|||||||
CHECK_HANDLE();
|
CHECK_HANDLE();
|
||||||
auto const &p_fmat = *static_cast<std::shared_ptr<DMatrix> *>(handle);
|
auto const &p_fmat = *static_cast<std::shared_ptr<DMatrix> *>(handle);
|
||||||
CHECK(type >= 1 && type <= 4);
|
CHECK(type >= 1 && type <= 4);
|
||||||
|
xgboost_CHECK_C_ARG_PTR(field);
|
||||||
p_fmat->SetInfo(field, data, static_cast<DataType>(type), size);
|
p_fmat->SetInfo(field, data, static_cast<DataType>(type), size);
|
||||||
API_END();
|
API_END();
|
||||||
}
|
}
|
||||||
@ -608,7 +658,10 @@ XGB_DLL int XGDMatrixGetFloatInfo(const DMatrixHandle handle,
|
|||||||
const bst_float** out_dptr) {
|
const bst_float** out_dptr) {
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
CHECK_HANDLE();
|
CHECK_HANDLE();
|
||||||
|
xgboost_CHECK_C_ARG_PTR(field);
|
||||||
const MetaInfo& info = static_cast<std::shared_ptr<DMatrix>*>(handle)->get()->Info();
|
const MetaInfo& info = static_cast<std::shared_ptr<DMatrix>*>(handle)->get()->Info();
|
||||||
|
xgboost_CHECK_C_ARG_PTR(out_len);
|
||||||
|
xgboost_CHECK_C_ARG_PTR(out_dptr);
|
||||||
info.GetInfo(field, out_len, DataType::kFloat32, reinterpret_cast<void const**>(out_dptr));
|
info.GetInfo(field, out_len, DataType::kFloat32, reinterpret_cast<void const**>(out_dptr));
|
||||||
API_END();
|
API_END();
|
||||||
}
|
}
|
||||||
@ -619,7 +672,10 @@ XGB_DLL int XGDMatrixGetUIntInfo(const DMatrixHandle handle,
|
|||||||
const unsigned **out_dptr) {
|
const unsigned **out_dptr) {
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
CHECK_HANDLE();
|
CHECK_HANDLE();
|
||||||
|
xgboost_CHECK_C_ARG_PTR(field);
|
||||||
const MetaInfo& info = static_cast<std::shared_ptr<DMatrix>*>(handle)->get()->Info();
|
const MetaInfo& info = static_cast<std::shared_ptr<DMatrix>*>(handle)->get()->Info();
|
||||||
|
xgboost_CHECK_C_ARG_PTR(out_len);
|
||||||
|
xgboost_CHECK_C_ARG_PTR(out_dptr);
|
||||||
info.GetInfo(field, out_len, DataType::kUInt32, reinterpret_cast<void const**>(out_dptr));
|
info.GetInfo(field, out_len, DataType::kUInt32, reinterpret_cast<void const**>(out_dptr));
|
||||||
API_END();
|
API_END();
|
||||||
}
|
}
|
||||||
@ -628,6 +684,7 @@ XGB_DLL int XGDMatrixNumRow(const DMatrixHandle handle,
|
|||||||
xgboost::bst_ulong *out) {
|
xgboost::bst_ulong *out) {
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
CHECK_HANDLE();
|
CHECK_HANDLE();
|
||||||
|
xgboost_CHECK_C_ARG_PTR(out);
|
||||||
*out = static_cast<xgboost::bst_ulong>(
|
*out = static_cast<xgboost::bst_ulong>(
|
||||||
static_cast<std::shared_ptr<DMatrix>*>(handle)->get()->Info().num_row_);
|
static_cast<std::shared_ptr<DMatrix>*>(handle)->get()->Info().num_row_);
|
||||||
API_END();
|
API_END();
|
||||||
@ -637,6 +694,7 @@ XGB_DLL int XGDMatrixNumCol(const DMatrixHandle handle,
|
|||||||
xgboost::bst_ulong *out) {
|
xgboost::bst_ulong *out) {
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
CHECK_HANDLE();
|
CHECK_HANDLE();
|
||||||
|
xgboost_CHECK_C_ARG_PTR(out);
|
||||||
*out = static_cast<xgboost::bst_ulong>(
|
*out = static_cast<xgboost::bst_ulong>(
|
||||||
static_cast<std::shared_ptr<DMatrix>*>(handle)->get()->Info().num_col_);
|
static_cast<std::shared_ptr<DMatrix>*>(handle)->get()->Info().num_col_);
|
||||||
API_END();
|
API_END();
|
||||||
@ -649,8 +707,10 @@ XGB_DLL int XGBoosterCreate(const DMatrixHandle dmats[],
|
|||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
std::vector<std::shared_ptr<DMatrix> > mats;
|
std::vector<std::shared_ptr<DMatrix> > mats;
|
||||||
for (xgboost::bst_ulong i = 0; i < len; ++i) {
|
for (xgboost::bst_ulong i = 0; i < len; ++i) {
|
||||||
|
xgboost_CHECK_C_ARG_PTR(dmats);
|
||||||
mats.push_back(*static_cast<std::shared_ptr<DMatrix>*>(dmats[i]));
|
mats.push_back(*static_cast<std::shared_ptr<DMatrix>*>(dmats[i]));
|
||||||
}
|
}
|
||||||
|
xgboost_CHECK_C_ARG_PTR(out);
|
||||||
*out = Learner::Create(mats);
|
*out = Learner::Create(mats);
|
||||||
API_END();
|
API_END();
|
||||||
}
|
}
|
||||||
@ -676,6 +736,7 @@ XGB_DLL int XGBoosterGetNumFeature(BoosterHandle handle,
|
|||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
CHECK_HANDLE();
|
CHECK_HANDLE();
|
||||||
static_cast<Learner*>(handle)->Configure();
|
static_cast<Learner*>(handle)->Configure();
|
||||||
|
xgboost_CHECK_C_ARG_PTR(out);
|
||||||
*out = static_cast<Learner*>(handle)->GetNumFeature();
|
*out = static_cast<Learner*>(handle)->GetNumFeature();
|
||||||
API_END();
|
API_END();
|
||||||
}
|
}
|
||||||
@ -684,6 +745,7 @@ XGB_DLL int XGBoosterBoostedRounds(BoosterHandle handle, int* out) {
|
|||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
CHECK_HANDLE();
|
CHECK_HANDLE();
|
||||||
static_cast<Learner*>(handle)->Configure();
|
static_cast<Learner*>(handle)->Configure();
|
||||||
|
xgboost_CHECK_C_ARG_PTR(out);
|
||||||
*out = static_cast<Learner*>(handle)->BoostedRounds();
|
*out = static_cast<Learner*>(handle)->BoostedRounds();
|
||||||
API_END();
|
API_END();
|
||||||
}
|
}
|
||||||
@ -691,6 +753,7 @@ XGB_DLL int XGBoosterBoostedRounds(BoosterHandle handle, int* out) {
|
|||||||
XGB_DLL int XGBoosterLoadJsonConfig(BoosterHandle handle, char const* json_parameters) {
|
XGB_DLL int XGBoosterLoadJsonConfig(BoosterHandle handle, char const* json_parameters) {
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
CHECK_HANDLE();
|
CHECK_HANDLE();
|
||||||
|
xgboost_CHECK_C_ARG_PTR(json_parameters);
|
||||||
Json config { Json::Load(StringView{json_parameters}) };
|
Json config { Json::Load(StringView{json_parameters}) };
|
||||||
static_cast<Learner*>(handle)->LoadConfig(config);
|
static_cast<Learner*>(handle)->LoadConfig(config);
|
||||||
API_END();
|
API_END();
|
||||||
@ -707,6 +770,10 @@ XGB_DLL int XGBoosterSaveJsonConfig(BoosterHandle handle,
|
|||||||
learner->SaveConfig(&config);
|
learner->SaveConfig(&config);
|
||||||
std::string& raw_str = learner->GetThreadLocal().ret_str;
|
std::string& raw_str = learner->GetThreadLocal().ret_str;
|
||||||
Json::Dump(config, &raw_str);
|
Json::Dump(config, &raw_str);
|
||||||
|
|
||||||
|
xgboost_CHECK_C_ARG_PTR(out_str);
|
||||||
|
xgboost_CHECK_C_ARG_PTR(out_len);
|
||||||
|
|
||||||
*out_str = raw_str.c_str();
|
*out_str = raw_str.c_str();
|
||||||
*out_len = static_cast<xgboost::bst_ulong>(raw_str.length());
|
*out_len = static_cast<xgboost::bst_ulong>(raw_str.length());
|
||||||
API_END();
|
API_END();
|
||||||
@ -718,9 +785,9 @@ XGB_DLL int XGBoosterUpdateOneIter(BoosterHandle handle,
|
|||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
CHECK_HANDLE();
|
CHECK_HANDLE();
|
||||||
auto* bst = static_cast<Learner*>(handle);
|
auto* bst = static_cast<Learner*>(handle);
|
||||||
auto *dtr =
|
xgboost_CHECK_C_ARG_PTR(dtrain);
|
||||||
static_cast<std::shared_ptr<DMatrix>*>(dtrain);
|
auto *dtr = static_cast<std::shared_ptr<DMatrix> *>(dtrain);
|
||||||
|
CHECK(dtr);
|
||||||
bst->UpdateOneIter(iter, *dtr);
|
bst->UpdateOneIter(iter, *dtr);
|
||||||
API_END();
|
API_END();
|
||||||
}
|
}
|
||||||
@ -738,6 +805,10 @@ XGB_DLL int XGBoosterBoostOneIter(BoosterHandle handle,
|
|||||||
static_cast<std::shared_ptr<DMatrix>*>(dtrain);
|
static_cast<std::shared_ptr<DMatrix>*>(dtrain);
|
||||||
tmp_gpair.Resize(len);
|
tmp_gpair.Resize(len);
|
||||||
std::vector<GradientPair>& tmp_gpair_h = tmp_gpair.HostVector();
|
std::vector<GradientPair>& tmp_gpair_h = tmp_gpair.HostVector();
|
||||||
|
if (len > 0) {
|
||||||
|
xgboost_CHECK_C_ARG_PTR(grad);
|
||||||
|
xgboost_CHECK_C_ARG_PTR(hess);
|
||||||
|
}
|
||||||
for (xgboost::bst_ulong i = 0; i < len; ++i) {
|
for (xgboost::bst_ulong i = 0; i < len; ++i) {
|
||||||
tmp_gpair_h[i] = GradientPair(grad[i], hess[i]);
|
tmp_gpair_h[i] = GradientPair(grad[i], hess[i]);
|
||||||
}
|
}
|
||||||
@ -761,11 +832,14 @@ XGB_DLL int XGBoosterEvalOneIter(BoosterHandle handle,
|
|||||||
std::vector<std::string> data_names;
|
std::vector<std::string> data_names;
|
||||||
|
|
||||||
for (xgboost::bst_ulong i = 0; i < len; ++i) {
|
for (xgboost::bst_ulong i = 0; i < len; ++i) {
|
||||||
|
xgboost_CHECK_C_ARG_PTR(dmats);
|
||||||
data_sets.push_back(*static_cast<std::shared_ptr<DMatrix>*>(dmats[i]));
|
data_sets.push_back(*static_cast<std::shared_ptr<DMatrix>*>(dmats[i]));
|
||||||
|
xgboost_CHECK_C_ARG_PTR(evnames);
|
||||||
data_names.emplace_back(evnames[i]);
|
data_names.emplace_back(evnames[i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
eval_str = bst->EvalOneIter(iter, data_sets, data_names);
|
eval_str = bst->EvalOneIter(iter, data_sets, data_names);
|
||||||
|
xgboost_CHECK_C_ARG_PTR(out_str);
|
||||||
*out_str = eval_str.c_str();
|
*out_str = eval_str.c_str();
|
||||||
API_END();
|
API_END();
|
||||||
}
|
}
|
||||||
@ -787,6 +861,10 @@ XGB_DLL int XGBoosterPredict(BoosterHandle handle,
|
|||||||
static_cast<bool>(training), (option_mask & 2) != 0,
|
static_cast<bool>(training), (option_mask & 2) != 0,
|
||||||
(option_mask & 4) != 0, (option_mask & 8) != 0,
|
(option_mask & 4) != 0, (option_mask & 8) != 0,
|
||||||
(option_mask & 16) != 0);
|
(option_mask & 16) != 0);
|
||||||
|
|
||||||
|
xgboost_CHECK_C_ARG_PTR(len);
|
||||||
|
xgboost_CHECK_C_ARG_PTR(out_result);
|
||||||
|
|
||||||
*out_result = dmlc::BeginPtr(entry.predictions.ConstHostVector());
|
*out_result = dmlc::BeginPtr(entry.predictions.ConstHostVector());
|
||||||
*len = static_cast<xgboost::bst_ulong>(entry.predictions.Size());
|
*len = static_cast<xgboost::bst_ulong>(entry.predictions.Size());
|
||||||
API_END();
|
API_END();
|
||||||
@ -805,6 +883,7 @@ XGB_DLL int XGBoosterPredictFromDMatrix(BoosterHandle handle,
|
|||||||
if (dmat == nullptr) {
|
if (dmat == nullptr) {
|
||||||
LOG(FATAL) << "DMatrix has not been initialized or has already been disposed.";
|
LOG(FATAL) << "DMatrix has not been initialized or has already been disposed.";
|
||||||
}
|
}
|
||||||
|
xgboost_CHECK_C_ARG_PTR(c_json_config);
|
||||||
auto config = Json::Load(StringView{c_json_config});
|
auto config = Json::Load(StringView{c_json_config});
|
||||||
|
|
||||||
auto *learner = static_cast<Learner*>(handle);
|
auto *learner = static_cast<Learner*>(handle);
|
||||||
@ -836,13 +915,20 @@ XGB_DLL int XGBoosterPredictFromDMatrix(BoosterHandle handle,
|
|||||||
iteration_begin, iteration_end, training,
|
iteration_begin, iteration_end, training,
|
||||||
type == PredictionType::kLeaf, contribs, approximate,
|
type == PredictionType::kLeaf, contribs, approximate,
|
||||||
interactions);
|
interactions);
|
||||||
|
|
||||||
|
xgboost_CHECK_C_ARG_PTR(out_result);
|
||||||
*out_result = dmlc::BeginPtr(entry.predictions.ConstHostVector());
|
*out_result = dmlc::BeginPtr(entry.predictions.ConstHostVector());
|
||||||
|
|
||||||
auto &shape = learner->GetThreadLocal().prediction_shape;
|
auto &shape = learner->GetThreadLocal().prediction_shape;
|
||||||
auto chunksize = p_m->Info().num_row_ == 0 ? 0 : entry.predictions.Size() / p_m->Info().num_row_;
|
auto chunksize = p_m->Info().num_row_ == 0 ? 0 : entry.predictions.Size() / p_m->Info().num_row_;
|
||||||
auto rounds = iteration_end - iteration_begin;
|
auto rounds = iteration_end - iteration_begin;
|
||||||
rounds = rounds == 0 ? learner->BoostedRounds() : rounds;
|
rounds = rounds == 0 ? learner->BoostedRounds() : rounds;
|
||||||
// Determine shape
|
// Determine shape
|
||||||
bool strict_shape = RequiredArg<Boolean>(config, "strict_shape", __func__);
|
bool strict_shape = RequiredArg<Boolean>(config, "strict_shape", __func__);
|
||||||
|
|
||||||
|
xgboost_CHECK_C_ARG_PTR(out_dim);
|
||||||
|
xgboost_CHECK_C_ARG_PTR(out_shape);
|
||||||
|
|
||||||
CalcPredictShape(strict_shape, type, p_m->Info().num_row_,
|
CalcPredictShape(strict_shape, type, p_m->Info().num_row_,
|
||||||
p_m->Info().num_col_, chunksize, learner->Groups(), rounds,
|
p_m->Info().num_col_, chunksize, learner->Groups(), rounds,
|
||||||
&shape, out_dim);
|
&shape, out_dim);
|
||||||
@ -853,6 +939,7 @@ XGB_DLL int XGBoosterPredictFromDMatrix(BoosterHandle handle,
|
|||||||
void InplacePredictImpl(std::shared_ptr<DMatrix> p_m, char const *c_json_config, Learner *learner,
|
void InplacePredictImpl(std::shared_ptr<DMatrix> p_m, char const *c_json_config, Learner *learner,
|
||||||
xgboost::bst_ulong const **out_shape, xgboost::bst_ulong *out_dim,
|
xgboost::bst_ulong const **out_shape, xgboost::bst_ulong *out_dim,
|
||||||
const float **out_result) {
|
const float **out_result) {
|
||||||
|
xgboost_CHECK_C_ARG_PTR(c_json_config);
|
||||||
auto config = Json::Load(StringView{c_json_config});
|
auto config = Json::Load(StringView{c_json_config});
|
||||||
CHECK_EQ(get<Integer const>(config["cache_id"]), 0) << "Cache ID is not supported yet";
|
CHECK_EQ(get<Integer const>(config["cache_id"]), 0) << "Cache ID is not supported yet";
|
||||||
|
|
||||||
@ -869,8 +956,14 @@ void InplacePredictImpl(std::shared_ptr<DMatrix> p_m, char const *c_json_config,
|
|||||||
auto n_features = info.num_col_;
|
auto n_features = info.num_col_;
|
||||||
auto chunksize = n_samples == 0 ? 0 : p_predt->Size() / n_samples;
|
auto chunksize = n_samples == 0 ? 0 : p_predt->Size() / n_samples;
|
||||||
bool strict_shape = RequiredArg<Boolean>(config, "strict_shape", __func__);
|
bool strict_shape = RequiredArg<Boolean>(config, "strict_shape", __func__);
|
||||||
|
|
||||||
|
xgboost_CHECK_C_ARG_PTR(out_dim);
|
||||||
CalcPredictShape(strict_shape, type, n_samples, n_features, chunksize, learner->Groups(),
|
CalcPredictShape(strict_shape, type, n_samples, n_features, chunksize, learner->Groups(),
|
||||||
learner->BoostedRounds(), &shape, out_dim);
|
learner->BoostedRounds(), &shape, out_dim);
|
||||||
|
|
||||||
|
xgboost_CHECK_C_ARG_PTR(out_result);
|
||||||
|
xgboost_CHECK_C_ARG_PTR(out_shape);
|
||||||
|
|
||||||
*out_result = dmlc::BeginPtr(p_predt->HostVector());
|
*out_result = dmlc::BeginPtr(p_predt->HostVector());
|
||||||
*out_shape = dmlc::BeginPtr(shape);
|
*out_shape = dmlc::BeginPtr(shape);
|
||||||
}
|
}
|
||||||
@ -889,6 +982,7 @@ XGB_DLL int XGBoosterPredictFromDense(BoosterHandle handle, char const *array_in
|
|||||||
}
|
}
|
||||||
auto proxy = dynamic_cast<data::DMatrixProxy *>(p_m.get());
|
auto proxy = dynamic_cast<data::DMatrixProxy *>(p_m.get());
|
||||||
CHECK(proxy) << "Invalid input type for inplace predict.";
|
CHECK(proxy) << "Invalid input type for inplace predict.";
|
||||||
|
xgboost_CHECK_C_ARG_PTR(array_interface);
|
||||||
proxy->SetArrayData(array_interface);
|
proxy->SetArrayData(array_interface);
|
||||||
auto *learner = static_cast<xgboost::Learner *>(handle);
|
auto *learner = static_cast<xgboost::Learner *>(handle);
|
||||||
InplacePredictImpl(p_m, c_json_config, learner, out_shape, out_dim, out_result);
|
InplacePredictImpl(p_m, c_json_config, learner, out_shape, out_dim, out_result);
|
||||||
@ -910,6 +1004,7 @@ XGB_DLL int XGBoosterPredictFromCSR(BoosterHandle handle, char const *indptr, ch
|
|||||||
}
|
}
|
||||||
auto proxy = dynamic_cast<data::DMatrixProxy *>(p_m.get());
|
auto proxy = dynamic_cast<data::DMatrixProxy *>(p_m.get());
|
||||||
CHECK(proxy) << "Invalid input type for inplace predict.";
|
CHECK(proxy) << "Invalid input type for inplace predict.";
|
||||||
|
xgboost_CHECK_C_ARG_PTR(indptr);
|
||||||
proxy->SetCSRData(indptr, indices, data, cols, true);
|
proxy->SetCSRData(indptr, indices, data, cols, true);
|
||||||
auto *learner = static_cast<xgboost::Learner *>(handle);
|
auto *learner = static_cast<xgboost::Learner *>(handle);
|
||||||
InplacePredictImpl(p_m, c_json_config, learner, out_shape, out_dim, out_result);
|
InplacePredictImpl(p_m, c_json_config, learner, out_shape, out_dim, out_result);
|
||||||
@ -941,6 +1036,7 @@ XGB_DLL int XGBoosterPredictFromCUDAColumnar(
|
|||||||
XGB_DLL int XGBoosterLoadModel(BoosterHandle handle, const char* fname) {
|
XGB_DLL int XGBoosterLoadModel(BoosterHandle handle, const char* fname) {
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
CHECK_HANDLE();
|
CHECK_HANDLE();
|
||||||
|
xgboost_CHECK_C_ARG_PTR(fname);
|
||||||
auto read_file = [&]() {
|
auto read_file = [&]() {
|
||||||
auto str = common::LoadSequentialFile(fname);
|
auto str = common::LoadSequentialFile(fname);
|
||||||
CHECK_GE(str.size(), 3); // "{}\0"
|
CHECK_GE(str.size(), 3); // "{}\0"
|
||||||
@ -971,10 +1067,12 @@ void WarnOldModel() {
|
|||||||
}
|
}
|
||||||
} // anonymous namespace
|
} // anonymous namespace
|
||||||
|
|
||||||
XGB_DLL int XGBoosterSaveModel(BoosterHandle handle, const char *c_fname) {
|
XGB_DLL int XGBoosterSaveModel(BoosterHandle handle, const char *fname) {
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
CHECK_HANDLE();
|
CHECK_HANDLE();
|
||||||
std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(c_fname, "w"));
|
xgboost_CHECK_C_ARG_PTR(fname);
|
||||||
|
|
||||||
|
std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(fname, "w"));
|
||||||
auto *learner = static_cast<Learner *>(handle);
|
auto *learner = static_cast<Learner *>(handle);
|
||||||
learner->Configure();
|
learner->Configure();
|
||||||
auto save_json = [&](std::ios::openmode mode) {
|
auto save_json = [&](std::ios::openmode mode) {
|
||||||
@ -984,9 +1082,9 @@ XGB_DLL int XGBoosterSaveModel(BoosterHandle handle, const char *c_fname) {
|
|||||||
Json::Dump(out, &str, mode);
|
Json::Dump(out, &str, mode);
|
||||||
fo->Write(str.data(), str.size());
|
fo->Write(str.data(), str.size());
|
||||||
};
|
};
|
||||||
if (common::FileExtension(c_fname) == "json") {
|
if (common::FileExtension(fname) == "json") {
|
||||||
save_json(std::ios::out);
|
save_json(std::ios::out);
|
||||||
} else if (common::FileExtension(c_fname) == "ubj") {
|
} else if (common::FileExtension(fname) == "ubj") {
|
||||||
save_json(std::ios::binary);
|
save_json(std::ios::binary);
|
||||||
} else if (XGBOOST_VER_MAJOR == 2 && XGBOOST_VER_MINOR >= 2) {
|
} else if (XGBOOST_VER_MAJOR == 2 && XGBOOST_VER_MINOR >= 2) {
|
||||||
LOG(WARNING) << "Saving model to JSON as default. You can use file extension `json`, `ubj` or "
|
LOG(WARNING) << "Saving model to JSON as default. You can use file extension `json`, `ubj` or "
|
||||||
@ -1004,6 +1102,8 @@ XGB_DLL int XGBoosterLoadModelFromBuffer(BoosterHandle handle, const void *buf,
|
|||||||
xgboost::bst_ulong len) {
|
xgboost::bst_ulong len) {
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
CHECK_HANDLE();
|
CHECK_HANDLE();
|
||||||
|
xgboost_CHECK_C_ARG_PTR(buf);
|
||||||
|
|
||||||
common::MemoryFixSizeBuffer fs((void *)buf, len); // NOLINT(*)
|
common::MemoryFixSizeBuffer fs((void *)buf, len); // NOLINT(*)
|
||||||
static_cast<Learner *>(handle)->LoadModel(&fs);
|
static_cast<Learner *>(handle)->LoadModel(&fs);
|
||||||
API_END();
|
API_END();
|
||||||
@ -1013,6 +1113,11 @@ XGB_DLL int XGBoosterSaveModelToBuffer(BoosterHandle handle, char const *json_co
|
|||||||
xgboost::bst_ulong *out_len, char const **out_dptr) {
|
xgboost::bst_ulong *out_len, char const **out_dptr) {
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
CHECK_HANDLE();
|
CHECK_HANDLE();
|
||||||
|
|
||||||
|
xgboost_CHECK_C_ARG_PTR(json_config);
|
||||||
|
xgboost_CHECK_C_ARG_PTR(out_dptr);
|
||||||
|
xgboost_CHECK_C_ARG_PTR(out_len);
|
||||||
|
|
||||||
auto config = Json::Load(StringView{json_config});
|
auto config = Json::Load(StringView{json_config});
|
||||||
auto format = RequiredArg<String>(config, "format", __func__);
|
auto format = RequiredArg<String>(config, "format", __func__);
|
||||||
|
|
||||||
@ -1039,6 +1144,7 @@ XGB_DLL int XGBoosterSaveModelToBuffer(BoosterHandle handle, char const *json_co
|
|||||||
raw_str.clear();
|
raw_str.clear();
|
||||||
common::MemoryBufferStream fo(&raw_str);
|
common::MemoryBufferStream fo(&raw_str);
|
||||||
learner->SaveModel(&fo);
|
learner->SaveModel(&fo);
|
||||||
|
|
||||||
*out_dptr = dmlc::BeginPtr(raw_str);
|
*out_dptr = dmlc::BeginPtr(raw_str);
|
||||||
*out_len = static_cast<xgboost::bst_ulong>(raw_str.size());
|
*out_len = static_cast<xgboost::bst_ulong>(raw_str.size());
|
||||||
} else {
|
} else {
|
||||||
@ -1048,11 +1154,11 @@ XGB_DLL int XGBoosterSaveModelToBuffer(BoosterHandle handle, char const *json_co
|
|||||||
API_END();
|
API_END();
|
||||||
}
|
}
|
||||||
|
|
||||||
XGB_DLL int XGBoosterGetModelRaw(BoosterHandle handle,
|
XGB_DLL int XGBoosterGetModelRaw(BoosterHandle handle, xgboost::bst_ulong *out_len,
|
||||||
xgboost::bst_ulong* out_len,
|
|
||||||
const char **out_dptr) {
|
const char **out_dptr) {
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
CHECK_HANDLE();
|
CHECK_HANDLE();
|
||||||
|
|
||||||
auto *learner = static_cast<Learner*>(handle);
|
auto *learner = static_cast<Learner*>(handle);
|
||||||
std::string& raw_str = learner->GetThreadLocal().ret_str;
|
std::string& raw_str = learner->GetThreadLocal().ret_str;
|
||||||
raw_str.resize(0);
|
raw_str.resize(0);
|
||||||
@ -1063,6 +1169,10 @@ XGB_DLL int XGBoosterGetModelRaw(BoosterHandle handle,
|
|||||||
|
|
||||||
learner->Configure();
|
learner->Configure();
|
||||||
learner->SaveModel(&fo);
|
learner->SaveModel(&fo);
|
||||||
|
|
||||||
|
xgboost_CHECK_C_ARG_PTR(out_dptr);
|
||||||
|
xgboost_CHECK_C_ARG_PTR(out_len);
|
||||||
|
|
||||||
*out_dptr = dmlc::BeginPtr(raw_str);
|
*out_dptr = dmlc::BeginPtr(raw_str);
|
||||||
*out_len = static_cast<xgboost::bst_ulong>(raw_str.length());
|
*out_len = static_cast<xgboost::bst_ulong>(raw_str.length());
|
||||||
API_END();
|
API_END();
|
||||||
@ -1070,17 +1180,21 @@ XGB_DLL int XGBoosterGetModelRaw(BoosterHandle handle,
|
|||||||
|
|
||||||
// The following two functions are `Load` and `Save` for memory based
|
// The following two functions are `Load` and `Save` for memory based
|
||||||
// serialization methods. E.g. Python pickle.
|
// serialization methods. E.g. Python pickle.
|
||||||
XGB_DLL int XGBoosterSerializeToBuffer(BoosterHandle handle,
|
XGB_DLL int XGBoosterSerializeToBuffer(BoosterHandle handle, xgboost::bst_ulong *out_len,
|
||||||
xgboost::bst_ulong *out_len,
|
|
||||||
const char **out_dptr) {
|
const char **out_dptr) {
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
CHECK_HANDLE();
|
CHECK_HANDLE();
|
||||||
|
|
||||||
auto *learner = static_cast<Learner *>(handle);
|
auto *learner = static_cast<Learner *>(handle);
|
||||||
std::string &raw_str = learner->GetThreadLocal().ret_str;
|
std::string &raw_str = learner->GetThreadLocal().ret_str;
|
||||||
raw_str.resize(0);
|
raw_str.resize(0);
|
||||||
common::MemoryBufferStream fo(&raw_str);
|
common::MemoryBufferStream fo(&raw_str);
|
||||||
learner->Configure();
|
learner->Configure();
|
||||||
learner->Save(&fo);
|
learner->Save(&fo);
|
||||||
|
|
||||||
|
xgboost_CHECK_C_ARG_PTR(out_dptr);
|
||||||
|
xgboost_CHECK_C_ARG_PTR(out_len);
|
||||||
|
|
||||||
*out_dptr = dmlc::BeginPtr(raw_str);
|
*out_dptr = dmlc::BeginPtr(raw_str);
|
||||||
*out_len = static_cast<xgboost::bst_ulong>(raw_str.length());
|
*out_len = static_cast<xgboost::bst_ulong>(raw_str.length());
|
||||||
API_END();
|
API_END();
|
||||||
@ -1091,6 +1205,8 @@ XGB_DLL int XGBoosterUnserializeFromBuffer(BoosterHandle handle,
|
|||||||
xgboost::bst_ulong len) {
|
xgboost::bst_ulong len) {
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
CHECK_HANDLE();
|
CHECK_HANDLE();
|
||||||
|
xgboost_CHECK_C_ARG_PTR(buf);
|
||||||
|
|
||||||
common::MemoryFixSizeBuffer fs((void*)buf, len); // NOLINT(*)
|
common::MemoryFixSizeBuffer fs((void*)buf, len); // NOLINT(*)
|
||||||
static_cast<Learner*>(handle)->Load(&fs);
|
static_cast<Learner*>(handle)->Load(&fs);
|
||||||
API_END();
|
API_END();
|
||||||
@ -1101,6 +1217,7 @@ XGB_DLL int XGBoosterLoadRabitCheckpoint(BoosterHandle handle,
|
|||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
CHECK_HANDLE();
|
CHECK_HANDLE();
|
||||||
auto* bst = static_cast<Learner*>(handle);
|
auto* bst = static_cast<Learner*>(handle);
|
||||||
|
xgboost_CHECK_C_ARG_PTR(version);
|
||||||
*version = rabit::LoadCheckPoint();
|
*version = rabit::LoadCheckPoint();
|
||||||
if (*version != 0) {
|
if (*version != 0) {
|
||||||
bst->Configure();
|
bst->Configure();
|
||||||
@ -1122,6 +1239,8 @@ XGB_DLL int XGBoosterSlice(BoosterHandle handle, int begin_layer,
|
|||||||
BoosterHandle *out) {
|
BoosterHandle *out) {
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
CHECK_HANDLE();
|
CHECK_HANDLE();
|
||||||
|
xgboost_CHECK_C_ARG_PTR(out);
|
||||||
|
|
||||||
auto* learner = static_cast<Learner*>(handle);
|
auto* learner = static_cast<Learner*>(handle);
|
||||||
bool out_of_bound = false;
|
bool out_of_bound = false;
|
||||||
auto p_out = learner->Slice(begin_layer, end_layer, step, &out_of_bound);
|
auto p_out = learner->Slice(begin_layer, end_layer, step, &out_of_bound);
|
||||||
@ -1148,6 +1267,10 @@ inline void XGBoostDumpModelImpl(BoosterHandle handle, FeatureMap* fmap,
|
|||||||
for (size_t i = 0; i < str_vecs.size(); ++i) {
|
for (size_t i = 0; i < str_vecs.size(); ++i) {
|
||||||
charp_vecs[i] = str_vecs[i].c_str();
|
charp_vecs[i] = str_vecs[i].c_str();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
xgboost_CHECK_C_ARG_PTR(out_models);
|
||||||
|
xgboost_CHECK_C_ARG_PTR(len);
|
||||||
|
|
||||||
*out_models = dmlc::BeginPtr(charp_vecs);
|
*out_models = dmlc::BeginPtr(charp_vecs);
|
||||||
*len = static_cast<xgboost::bst_ulong>(charp_vecs.size());
|
*len = static_cast<xgboost::bst_ulong>(charp_vecs.size());
|
||||||
}
|
}
|
||||||
@ -1171,6 +1294,8 @@ XGB_DLL int XGBoosterDumpModelEx(BoosterHandle handle,
|
|||||||
const char*** out_models) {
|
const char*** out_models) {
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
CHECK_HANDLE();
|
CHECK_HANDLE();
|
||||||
|
|
||||||
|
xgboost_CHECK_C_ARG_PTR(fmap);
|
||||||
std::string uri{fmap};
|
std::string uri{fmap};
|
||||||
FeatureMap featmap = LoadFeatureMap(uri);
|
FeatureMap featmap = LoadFeatureMap(uri);
|
||||||
XGBoostDumpModelImpl(handle, &featmap, with_stats, format, len, out_models);
|
XGBoostDumpModelImpl(handle, &featmap, with_stats, format, len, out_models);
|
||||||
@ -1200,20 +1325,24 @@ XGB_DLL int XGBoosterDumpModelExWithFeatures(BoosterHandle handle,
|
|||||||
CHECK_HANDLE();
|
CHECK_HANDLE();
|
||||||
FeatureMap featmap;
|
FeatureMap featmap;
|
||||||
for (int i = 0; i < fnum; ++i) {
|
for (int i = 0; i < fnum; ++i) {
|
||||||
|
xgboost_CHECK_C_ARG_PTR(fname);
|
||||||
|
xgboost_CHECK_C_ARG_PTR(ftype);
|
||||||
featmap.PushBack(i, fname[i], ftype[i]);
|
featmap.PushBack(i, fname[i], ftype[i]);
|
||||||
}
|
}
|
||||||
XGBoostDumpModelImpl(handle, &featmap, with_stats, format, len, out_models);
|
XGBoostDumpModelImpl(handle, &featmap, with_stats, format, len, out_models);
|
||||||
API_END();
|
API_END();
|
||||||
}
|
}
|
||||||
|
|
||||||
XGB_DLL int XGBoosterGetAttr(BoosterHandle handle,
|
XGB_DLL int XGBoosterGetAttr(BoosterHandle handle, const char *key, const char **out,
|
||||||
const char* key,
|
|
||||||
const char** out,
|
|
||||||
int *success) {
|
int *success) {
|
||||||
auto* bst = static_cast<Learner*>(handle);
|
auto* bst = static_cast<Learner*>(handle);
|
||||||
std::string& ret_str = bst->GetThreadLocal().ret_str;
|
std::string& ret_str = bst->GetThreadLocal().ret_str;
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
CHECK_HANDLE();
|
CHECK_HANDLE();
|
||||||
|
|
||||||
|
xgboost_CHECK_C_ARG_PTR(out);
|
||||||
|
xgboost_CHECK_C_ARG_PTR(success);
|
||||||
|
|
||||||
if (bst->GetAttr(key, &ret_str)) {
|
if (bst->GetAttr(key, &ret_str)) {
|
||||||
*out = ret_str.c_str();
|
*out = ret_str.c_str();
|
||||||
*success = 1;
|
*success = 1;
|
||||||
@ -1230,9 +1359,11 @@ XGB_DLL int XGBoosterSetAttr(BoosterHandle handle,
|
|||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
CHECK_HANDLE();
|
CHECK_HANDLE();
|
||||||
auto* bst = static_cast<Learner*>(handle);
|
auto* bst = static_cast<Learner*>(handle);
|
||||||
|
xgboost_CHECK_C_ARG_PTR(key);
|
||||||
if (value == nullptr) {
|
if (value == nullptr) {
|
||||||
bst->DelAttr(key);
|
bst->DelAttr(key);
|
||||||
} else {
|
} else {
|
||||||
|
xgboost_CHECK_C_ARG_PTR(value);
|
||||||
bst->SetAttr(key, value);
|
bst->SetAttr(key, value);
|
||||||
}
|
}
|
||||||
API_END();
|
API_END();
|
||||||
@ -1243,6 +1374,7 @@ XGB_DLL int XGBoosterGetAttrNames(BoosterHandle handle,
|
|||||||
const char*** out) {
|
const char*** out) {
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
CHECK_HANDLE();
|
CHECK_HANDLE();
|
||||||
|
|
||||||
auto *learner = static_cast<Learner *>(handle);
|
auto *learner = static_cast<Learner *>(handle);
|
||||||
std::vector<std::string> &str_vecs = learner->GetThreadLocal().ret_vec_str;
|
std::vector<std::string> &str_vecs = learner->GetThreadLocal().ret_vec_str;
|
||||||
std::vector<const char *> &charp_vecs =
|
std::vector<const char *> &charp_vecs =
|
||||||
@ -1252,6 +1384,10 @@ XGB_DLL int XGBoosterGetAttrNames(BoosterHandle handle,
|
|||||||
for (size_t i = 0; i < str_vecs.size(); ++i) {
|
for (size_t i = 0; i < str_vecs.size(); ++i) {
|
||||||
charp_vecs[i] = str_vecs[i].c_str();
|
charp_vecs[i] = str_vecs[i].c_str();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
xgboost_CHECK_C_ARG_PTR(out);
|
||||||
|
xgboost_CHECK_C_ARG_PTR(out_len);
|
||||||
|
|
||||||
*out = dmlc::BeginPtr(charp_vecs);
|
*out = dmlc::BeginPtr(charp_vecs);
|
||||||
*out_len = static_cast<xgboost::bst_ulong>(charp_vecs.size());
|
*out_len = static_cast<xgboost::bst_ulong>(charp_vecs.size());
|
||||||
API_END();
|
API_END();
|
||||||
@ -1264,9 +1400,14 @@ XGB_DLL int XGBoosterSetStrFeatureInfo(BoosterHandle handle, const char *field,
|
|||||||
CHECK_HANDLE();
|
CHECK_HANDLE();
|
||||||
auto *learner = static_cast<Learner *>(handle);
|
auto *learner = static_cast<Learner *>(handle);
|
||||||
std::vector<std::string> feature_info;
|
std::vector<std::string> feature_info;
|
||||||
|
if (size > 0) {
|
||||||
|
xgboost_CHECK_C_ARG_PTR(features);
|
||||||
|
}
|
||||||
for (size_t i = 0; i < size; ++i) {
|
for (size_t i = 0; i < size; ++i) {
|
||||||
feature_info.emplace_back(features[i]);
|
feature_info.emplace_back(features[i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
xgboost_CHECK_C_ARG_PTR(field);
|
||||||
if (!std::strcmp(field, "feature_name")) {
|
if (!std::strcmp(field, "feature_name")) {
|
||||||
learner->SetFeatureNames(feature_info);
|
learner->SetFeatureNames(feature_info);
|
||||||
} else if (!std::strcmp(field, "feature_type")) {
|
} else if (!std::strcmp(field, "feature_type")) {
|
||||||
@ -1297,20 +1438,23 @@ XGB_DLL int XGBoosterGetStrFeatureInfo(BoosterHandle handle, const char *field,
|
|||||||
for (size_t i = 0; i < str_vecs.size(); ++i) {
|
for (size_t i = 0; i < str_vecs.size(); ++i) {
|
||||||
charp_vecs[i] = str_vecs[i].c_str();
|
charp_vecs[i] = str_vecs[i].c_str();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
xgboost_CHECK_C_ARG_PTR(out_features);
|
||||||
|
xgboost_CHECK_C_ARG_PTR(len);
|
||||||
|
|
||||||
*out_features = dmlc::BeginPtr(charp_vecs);
|
*out_features = dmlc::BeginPtr(charp_vecs);
|
||||||
*len = static_cast<xgboost::bst_ulong>(charp_vecs.size());
|
*len = static_cast<xgboost::bst_ulong>(charp_vecs.size());
|
||||||
API_END();
|
API_END();
|
||||||
}
|
}
|
||||||
|
|
||||||
XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, char const *json_config,
|
XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, char const *json_config,
|
||||||
xgboost::bst_ulong *out_n_features,
|
xgboost::bst_ulong *out_n_features, char const ***out_features,
|
||||||
char const ***out_features,
|
bst_ulong *out_dim, bst_ulong const **out_shape,
|
||||||
bst_ulong *out_dim,
|
|
||||||
bst_ulong const **out_shape,
|
|
||||||
float const **out_scores) {
|
float const **out_scores) {
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
CHECK_HANDLE();
|
CHECK_HANDLE();
|
||||||
auto *learner = static_cast<Learner *>(handle);
|
auto *learner = static_cast<Learner *>(handle);
|
||||||
|
xgboost_CHECK_C_ARG_PTR(json_config);
|
||||||
auto config = Json::Load(StringView{json_config});
|
auto config = Json::Load(StringView{json_config});
|
||||||
|
|
||||||
auto importance = RequiredArg<String>(config, "importance_type", __func__);
|
auto importance = RequiredArg<String>(config, "importance_type", __func__);
|
||||||
@ -1348,10 +1492,13 @@ XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, char const *json_config,
|
|||||||
feature_names[i] = feature_map.Name(features[i]);
|
feature_names[i] = feature_map.Name(features[i]);
|
||||||
feature_names_c[i] = feature_names[i].data();
|
feature_names_c[i] = feature_names[i].data();
|
||||||
}
|
}
|
||||||
|
xgboost_CHECK_C_ARG_PTR(out_n_features);
|
||||||
*out_n_features = feature_names.size();
|
*out_n_features = feature_names.size();
|
||||||
|
|
||||||
CHECK_LE(features.size(), scores.size());
|
CHECK_LE(features.size(), scores.size());
|
||||||
auto &shape = learner->GetThreadLocal().prediction_shape;
|
auto &shape = learner->GetThreadLocal().prediction_shape;
|
||||||
|
|
||||||
|
xgboost_CHECK_C_ARG_PTR(out_dim);
|
||||||
if (scores.size() > features.size()) {
|
if (scores.size() > features.size()) {
|
||||||
// Linear model multi-class model
|
// Linear model multi-class model
|
||||||
CHECK_EQ(scores.size() % features.size(), 0ul);
|
CHECK_EQ(scores.size() % features.size(), 0ul);
|
||||||
@ -1365,6 +1512,10 @@ XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, char const *json_config,
|
|||||||
shape.front() = scores.size();
|
shape.front() = scores.size();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
xgboost_CHECK_C_ARG_PTR(out_shape);
|
||||||
|
xgboost_CHECK_C_ARG_PTR(out_scores);
|
||||||
|
xgboost_CHECK_C_ARG_PTR(out_features);
|
||||||
|
|
||||||
*out_shape = dmlc::BeginPtr(shape);
|
*out_shape = dmlc::BeginPtr(shape);
|
||||||
*out_scores = scores.data();
|
*out_scores = scores.data();
|
||||||
*out_features = dmlc::BeginPtr(feature_names_c);
|
*out_features = dmlc::BeginPtr(feature_names_c);
|
||||||
@ -1375,26 +1526,27 @@ using xgboost::collective::Communicator;
|
|||||||
|
|
||||||
XGB_DLL int XGCommunicatorInit(char const* json_config) {
|
XGB_DLL int XGCommunicatorInit(char const* json_config) {
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
|
xgboost_CHECK_C_ARG_PTR(json_config);
|
||||||
Json config { Json::Load(StringView{json_config}) };
|
Json config { Json::Load(StringView{json_config}) };
|
||||||
Communicator::Init(config);
|
Communicator::Init(config);
|
||||||
API_END();
|
API_END();
|
||||||
}
|
}
|
||||||
|
|
||||||
XGB_DLL int XGCommunicatorFinalize(void) {
|
XGB_DLL int XGCommunicatorFinalize() {
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
Communicator::Finalize();
|
Communicator::Finalize();
|
||||||
API_END();
|
API_END();
|
||||||
}
|
}
|
||||||
|
|
||||||
XGB_DLL int XGCommunicatorGetRank(void) {
|
XGB_DLL int XGCommunicatorGetRank() {
|
||||||
return Communicator::Get()->GetRank();
|
return Communicator::Get()->GetRank();
|
||||||
}
|
}
|
||||||
|
|
||||||
XGB_DLL int XGCommunicatorGetWorldSize(void) {
|
XGB_DLL int XGCommunicatorGetWorldSize() {
|
||||||
return Communicator::Get()->GetWorldSize();
|
return Communicator::Get()->GetWorldSize();
|
||||||
}
|
}
|
||||||
|
|
||||||
XGB_DLL int XGCommunicatorIsDistributed(void) {
|
XGB_DLL int XGCommunicatorIsDistributed() {
|
||||||
return Communicator::Get()->IsDistributed();
|
return Communicator::Get()->IsDistributed();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1408,6 +1560,7 @@ XGB_DLL int XGCommunicatorGetProcessorName(char const **name_str) {
|
|||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
auto& local = *GlobalConfigAPIThreadLocalStore::Get();
|
auto& local = *GlobalConfigAPIThreadLocalStore::Get();
|
||||||
local.ret_str = Communicator::Get()->GetProcessorName();
|
local.ret_str = Communicator::Get()->GetProcessorName();
|
||||||
|
xgboost_CHECK_C_ARG_PTR(name_str);
|
||||||
*name_str = local.ret_str.c_str();
|
*name_str = local.ret_str.c_str();
|
||||||
API_END();
|
API_END();
|
||||||
}
|
}
|
||||||
|
|||||||
@ -62,8 +62,13 @@ XGB_DLL int XGDMatrixCreateFromCudaColumnar(char const *data,
|
|||||||
char const* c_json_config,
|
char const* c_json_config,
|
||||||
DMatrixHandle *out) {
|
DMatrixHandle *out) {
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
|
|
||||||
|
xgboost_CHECK_C_ARG_PTR(c_json_config);
|
||||||
|
xgboost_CHECK_C_ARG_PTR(data);
|
||||||
|
|
||||||
std::string json_str{data};
|
std::string json_str{data};
|
||||||
auto config = Json::Load(StringView{c_json_config});
|
auto config = Json::Load(StringView{c_json_config});
|
||||||
|
|
||||||
float missing = GetMissing(config);
|
float missing = GetMissing(config);
|
||||||
auto nthread = get<Integer const>(config["nthread"]);
|
auto nthread = get<Integer const>(config["nthread"]);
|
||||||
data::CudfAdapter adapter(json_str);
|
data::CudfAdapter adapter(json_str);
|
||||||
@ -97,6 +102,7 @@ int InplacePreidctCuda(BoosterHandle handle, char const *c_array_interface,
|
|||||||
}
|
}
|
||||||
auto proxy = dynamic_cast<data::DMatrixProxy *>(p_m.get());
|
auto proxy = dynamic_cast<data::DMatrixProxy *>(p_m.get());
|
||||||
CHECK(proxy) << "Invalid input type for inplace predict.";
|
CHECK(proxy) << "Invalid input type for inplace predict.";
|
||||||
|
|
||||||
proxy->SetCUDAArray(c_array_interface);
|
proxy->SetCUDAArray(c_array_interface);
|
||||||
|
|
||||||
auto config = Json::Load(StringView{c_json_config});
|
auto config = Json::Load(StringView{c_json_config});
|
||||||
@ -117,6 +123,11 @@ int InplacePreidctCuda(BoosterHandle handle, char const *c_array_interface,
|
|||||||
size_t n_samples = p_m->Info().num_row_;
|
size_t n_samples = p_m->Info().num_row_;
|
||||||
auto chunksize = n_samples == 0 ? 0 : p_predt->Size() / n_samples;
|
auto chunksize = n_samples == 0 ? 0 : p_predt->Size() / n_samples;
|
||||||
bool strict_shape = RequiredArg<Boolean>(config, "strict_shape", __func__);
|
bool strict_shape = RequiredArg<Boolean>(config, "strict_shape", __func__);
|
||||||
|
|
||||||
|
xgboost_CHECK_C_ARG_PTR(out_result);
|
||||||
|
xgboost_CHECK_C_ARG_PTR(out_shape);
|
||||||
|
xgboost_CHECK_C_ARG_PTR(out_dim);
|
||||||
|
|
||||||
CalcPredictShape(strict_shape, type, n_samples, p_m->Info().num_col_, chunksize,
|
CalcPredictShape(strict_shape, type, n_samples, p_m->Info().num_col_, chunksize,
|
||||||
learner->Groups(), learner->BoostedRounds(), &shape, out_dim);
|
learner->Groups(), learner->BoostedRounds(), &shape, out_dim);
|
||||||
*out_shape = dmlc::BeginPtr(shape);
|
*out_shape = dmlc::BeginPtr(shape);
|
||||||
@ -130,6 +141,7 @@ XGB_DLL int XGBoosterPredictFromCudaColumnar(BoosterHandle handle, char const *c
|
|||||||
xgboost::bst_ulong *out_dim,
|
xgboost::bst_ulong *out_dim,
|
||||||
const float **out_result) {
|
const float **out_result) {
|
||||||
std::shared_ptr<DMatrix> p_m{nullptr};
|
std::shared_ptr<DMatrix> p_m{nullptr};
|
||||||
|
xgboost_CHECK_C_ARG_PTR(c_json_config);
|
||||||
if (m) {
|
if (m) {
|
||||||
p_m = *static_cast<std::shared_ptr<DMatrix> *>(m);
|
p_m = *static_cast<std::shared_ptr<DMatrix> *>(m);
|
||||||
}
|
}
|
||||||
@ -145,6 +157,7 @@ XGB_DLL int XGBoosterPredictFromCudaArray(BoosterHandle handle, char const *c_js
|
|||||||
if (m) {
|
if (m) {
|
||||||
p_m = *static_cast<std::shared_ptr<DMatrix> *>(m);
|
p_m = *static_cast<std::shared_ptr<DMatrix> *>(m);
|
||||||
}
|
}
|
||||||
|
xgboost_CHECK_C_ARG_PTR(out_result);
|
||||||
return InplacePreidctCuda(handle, c_json_strs, c_json_config, p_m, out_shape, out_dim,
|
return InplacePreidctCuda(handle, c_json_strs, c_json_config, p_m, out_shape, out_dim,
|
||||||
out_result);
|
out_result);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*!
|
/*!
|
||||||
* Copyright (c) 2015 by Contributors
|
* Copyright (c) 2015-2022 by Contributors
|
||||||
* \file c_api_error.h
|
* \file c_api_error.h
|
||||||
* \brief Error handling for C API.
|
* \brief Error handling for C API.
|
||||||
*/
|
*/
|
||||||
@ -52,4 +52,12 @@ inline int XGBAPIHandleException(const dmlc::Error &e) {
|
|||||||
XGBAPISetLastError(e.what());
|
XGBAPISetLastError(e.what());
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#define xgboost_CHECK_C_ARG_PTR(out_ptr) \
|
||||||
|
do { \
|
||||||
|
if (XGBOOST_EXPECT(!(out_ptr), false)) { \
|
||||||
|
LOG(FATAL) << "Invalid pointer argument: " << #out_ptr; \
|
||||||
|
} \
|
||||||
|
} while (0)
|
||||||
|
|
||||||
#endif // XGBOOST_C_API_C_API_ERROR_H_
|
#endif // XGBOOST_C_API_C_API_ERROR_H_
|
||||||
|
|||||||
@ -514,6 +514,7 @@ void MetaInfo::SetInfoFromHost(Context const& ctx, StringView key, Json arr) {
|
|||||||
|
|
||||||
void MetaInfo::SetInfo(Context const& ctx, const char* key, const void* dptr, DataType dtype,
|
void MetaInfo::SetInfo(Context const& ctx, const char* key, const void* dptr, DataType dtype,
|
||||||
size_t num) {
|
size_t num) {
|
||||||
|
CHECK(key);
|
||||||
auto proc = [&](auto cast_d_ptr) {
|
auto proc = [&](auto cast_d_ptr) {
|
||||||
using T = std::remove_pointer_t<decltype(cast_d_ptr)>;
|
using T = std::remove_pointer_t<decltype(cast_d_ptr)>;
|
||||||
auto t = linalg::TensorView<T, 1>(common::Span<T>{cast_d_ptr, num}, {num}, Context::kCpuId);
|
auto t = linalg::TensorView<T, 1>(common::Span<T>{cast_d_ptr, num}, {num}, Context::kCpuId);
|
||||||
@ -588,8 +589,8 @@ void MetaInfo::GetInfo(char const* key, bst_ulong* out_len, DataType dtype,
|
|||||||
|
|
||||||
void MetaInfo::SetFeatureInfo(const char* key, const char **info, const bst_ulong size) {
|
void MetaInfo::SetFeatureInfo(const char* key, const char **info, const bst_ulong size) {
|
||||||
if (size != 0 && this->num_col_ != 0) {
|
if (size != 0 && this->num_col_ != 0) {
|
||||||
CHECK_EQ(size, this->num_col_)
|
CHECK_EQ(size, this->num_col_) << "Length of " << key << " must be equal to number of columns.";
|
||||||
<< "Length of " << key << " must be equal to number of columns.";
|
CHECK(info);
|
||||||
}
|
}
|
||||||
if (!std::strcmp(key, "feature_type")) {
|
if (!std::strcmp(key, "feature_type")) {
|
||||||
feature_type_names.clear();
|
feature_type_names.clear();
|
||||||
|
|||||||
@ -316,4 +316,12 @@ TEST(CAPI, BuildInfo) {
|
|||||||
ASSERT_TRUE(get<Object const>(loaded).find("USE_CUDA") != get<Object const>(loaded).cend());
|
ASSERT_TRUE(get<Object const>(loaded).find("USE_CUDA") != get<Object const>(loaded).cend());
|
||||||
ASSERT_TRUE(get<Object const>(loaded).find("USE_NCCL") != get<Object const>(loaded).cend());
|
ASSERT_TRUE(get<Object const>(loaded).find("USE_NCCL") != get<Object const>(loaded).cend());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(CAPI, NullPtr) {
|
||||||
|
ASSERT_EQ(XGBSetGlobalConfig(nullptr), -1);
|
||||||
|
auto const *err = XGBGetLastError();
|
||||||
|
auto pos = std::string{err}.find("Invalid pointer argument: json_str");
|
||||||
|
ASSERT_NE(pos, std::string::npos);
|
||||||
|
XGBAPISetLastError("");
|
||||||
|
}
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user