Convert handle == nullptr from SegFault to user-friendly error. (#3021)
* Convert SegFault to user-friendly error. * Apply the change to DMatrix API as well
This commit is contained in:
parent
8bec8d5e9a
commit
30d10ab035
@ -556,11 +556,12 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterPredict
|
||||
bst_ulong len;
|
||||
float *result;
|
||||
int ret = XGBoosterPredict(handle, dmat, joption_mask, (unsigned int) jntree_limit, &len, (const float **) &result);
|
||||
|
||||
if (len) {
|
||||
jsize jlen = (jsize) len;
|
||||
jfloatArray jarray = jenv->NewFloatArray(jlen);
|
||||
jenv->SetFloatArrayRegion(jarray, 0, jlen, (jfloat *) result);
|
||||
jenv->SetObjectArrayElement(jout, 0, jarray);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
||||
@ -650,6 +650,7 @@ XGB_DLL int XGDMatrixSliceDMatrix(DMatrixHandle handle,
|
||||
std::unique_ptr<data::SimpleCSRSource> source(new data::SimpleCSRSource());
|
||||
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
data::SimpleCSRSource src;
|
||||
src.CopyFrom(static_cast<std::shared_ptr<DMatrix>*>(handle)->get());
|
||||
data::SimpleCSRSource& ret = *source;
|
||||
@ -694,6 +695,7 @@ XGB_DLL int XGDMatrixSliceDMatrix(DMatrixHandle handle,
|
||||
|
||||
XGB_DLL int XGDMatrixFree(DMatrixHandle handle) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
delete static_cast<std::shared_ptr<DMatrix>*>(handle);
|
||||
API_END();
|
||||
}
|
||||
@ -702,6 +704,7 @@ XGB_DLL int XGDMatrixSaveBinary(DMatrixHandle handle,
|
||||
const char* fname,
|
||||
int silent) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
static_cast<std::shared_ptr<DMatrix>*>(handle)->get()->SaveToLocalFile(fname);
|
||||
API_END();
|
||||
}
|
||||
@ -711,6 +714,7 @@ XGB_DLL int XGDMatrixSetFloatInfo(DMatrixHandle handle,
|
||||
const bst_float* info,
|
||||
xgboost::bst_ulong len) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
static_cast<std::shared_ptr<DMatrix>*>(handle)
|
||||
->get()->Info().SetInfo(field, info, kFloat32, len);
|
||||
API_END();
|
||||
@ -721,6 +725,7 @@ XGB_DLL int XGDMatrixSetUIntInfo(DMatrixHandle handle,
|
||||
const unsigned* info,
|
||||
xgboost::bst_ulong len) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
static_cast<std::shared_ptr<DMatrix>*>(handle)
|
||||
->get()->Info().SetInfo(field, info, kUInt32, len);
|
||||
API_END();
|
||||
@ -730,6 +735,7 @@ XGB_DLL int XGDMatrixSetGroup(DMatrixHandle handle,
|
||||
const unsigned* group,
|
||||
xgboost::bst_ulong len) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
auto *pmat = static_cast<std::shared_ptr<DMatrix>*>(handle);
|
||||
MetaInfo& info = pmat->get()->Info();
|
||||
info.group_ptr_.resize(len + 1);
|
||||
@ -745,6 +751,7 @@ XGB_DLL int XGDMatrixGetFloatInfo(const DMatrixHandle handle,
|
||||
xgboost::bst_ulong* out_len,
|
||||
const bst_float** out_dptr) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
const MetaInfo& info = static_cast<std::shared_ptr<DMatrix>*>(handle)->get()->Info();
|
||||
const std::vector<bst_float>* vec = nullptr;
|
||||
if (!std::strcmp(field, "label")) {
|
||||
@ -766,6 +773,7 @@ XGB_DLL int XGDMatrixGetUIntInfo(const DMatrixHandle handle,
|
||||
xgboost::bst_ulong *out_len,
|
||||
const unsigned **out_dptr) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
const MetaInfo& info = static_cast<std::shared_ptr<DMatrix>*>(handle)->get()->Info();
|
||||
const std::vector<unsigned>* vec = nullptr;
|
||||
if (!std::strcmp(field, "root_index")) {
|
||||
@ -781,6 +789,7 @@ XGB_DLL int XGDMatrixGetUIntInfo(const DMatrixHandle handle,
|
||||
XGB_DLL int XGDMatrixNumRow(const DMatrixHandle handle,
|
||||
xgboost::bst_ulong *out) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
*out = static_cast<xgboost::bst_ulong>(
|
||||
static_cast<std::shared_ptr<DMatrix>*>(handle)->get()->Info().num_row_);
|
||||
API_END();
|
||||
@ -789,6 +798,7 @@ XGB_DLL int XGDMatrixNumRow(const DMatrixHandle handle,
|
||||
XGB_DLL int XGDMatrixNumCol(const DMatrixHandle handle,
|
||||
xgboost::bst_ulong *out) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
*out = static_cast<size_t>(
|
||||
static_cast<std::shared_ptr<DMatrix>*>(handle)->get()->Info().num_col_);
|
||||
API_END();
|
||||
@ -809,6 +819,7 @@ XGB_DLL int XGBoosterCreate(const DMatrixHandle dmats[],
|
||||
|
||||
XGB_DLL int XGBoosterFree(BoosterHandle handle) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
delete static_cast<Booster*>(handle);
|
||||
API_END();
|
||||
}
|
||||
@ -817,6 +828,7 @@ XGB_DLL int XGBoosterSetParam(BoosterHandle handle,
|
||||
const char *name,
|
||||
const char *value) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
static_cast<Booster*>(handle)->SetParam(name, value);
|
||||
API_END();
|
||||
}
|
||||
@ -825,6 +837,7 @@ XGB_DLL int XGBoosterUpdateOneIter(BoosterHandle handle,
|
||||
int iter,
|
||||
DMatrixHandle dtrain) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
auto* bst = static_cast<Booster*>(handle);
|
||||
auto *dtr =
|
||||
static_cast<std::shared_ptr<DMatrix>*>(dtrain);
|
||||
@ -841,6 +854,7 @@ XGB_DLL int XGBoosterBoostOneIter(BoosterHandle handle,
|
||||
xgboost::bst_ulong len) {
|
||||
HostDeviceVector<GradientPair>& tmp_gpair = XGBAPIThreadLocalStore::Get()->tmp_gpair;
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
auto* bst = static_cast<Booster*>(handle);
|
||||
auto* dtr =
|
||||
static_cast<std::shared_ptr<DMatrix>*>(dtrain);
|
||||
@ -863,6 +877,7 @@ XGB_DLL int XGBoosterEvalOneIter(BoosterHandle handle,
|
||||
const char** out_str) {
|
||||
std::string& eval_str = XGBAPIThreadLocalStore::Get()->ret_str;
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
auto* bst = static_cast<Booster*>(handle);
|
||||
std::vector<DMatrix*> data_sets;
|
||||
std::vector<std::string> data_names;
|
||||
@ -887,6 +902,7 @@ XGB_DLL int XGBoosterPredict(BoosterHandle handle,
|
||||
HostDeviceVector<bst_float>& preds =
|
||||
XGBAPIThreadLocalStore::Get()->ret_vec_float;
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
auto *bst = static_cast<Booster*>(handle);
|
||||
bst->LazyInit();
|
||||
bst->learner()->Predict(
|
||||
@ -904,6 +920,7 @@ XGB_DLL int XGBoosterPredict(BoosterHandle handle,
|
||||
|
||||
XGB_DLL int XGBoosterLoadModel(BoosterHandle handle, const char* fname) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname, "r"));
|
||||
static_cast<Booster*>(handle)->LoadModel(fi.get());
|
||||
API_END();
|
||||
@ -911,6 +928,7 @@ XGB_DLL int XGBoosterLoadModel(BoosterHandle handle, const char* fname) {
|
||||
|
||||
XGB_DLL int XGBoosterSaveModel(BoosterHandle handle, const char* fname) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(fname, "w"));
|
||||
auto *bst = static_cast<Booster*>(handle);
|
||||
bst->LazyInit();
|
||||
@ -922,6 +940,7 @@ XGB_DLL int XGBoosterLoadModelFromBuffer(BoosterHandle handle,
|
||||
const void* buf,
|
||||
xgboost::bst_ulong len) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
common::MemoryFixSizeBuffer fs((void*)buf, len); // NOLINT(*)
|
||||
static_cast<Booster*>(handle)->LoadModel(&fs);
|
||||
API_END();
|
||||
@ -934,6 +953,7 @@ XGB_DLL int XGBoosterGetModelRaw(BoosterHandle handle,
|
||||
raw_str.resize(0);
|
||||
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
common::MemoryBufferStream fo(&raw_str);
|
||||
auto *bst = static_cast<Booster*>(handle);
|
||||
bst->LazyInit();
|
||||
@ -976,6 +996,7 @@ XGB_DLL int XGBoosterDumpModelEx(BoosterHandle handle,
|
||||
xgboost::bst_ulong* len,
|
||||
const char*** out_models) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
FeatureMap featmap;
|
||||
if (strlen(fmap) != 0) {
|
||||
std::unique_ptr<dmlc::Stream> fs(
|
||||
@ -1006,6 +1027,7 @@ XGB_DLL int XGBoosterDumpModelExWithFeatures(BoosterHandle handle,
|
||||
xgboost::bst_ulong* len,
|
||||
const char*** out_models) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
FeatureMap featmap;
|
||||
for (int i = 0; i < fnum; ++i) {
|
||||
featmap.PushBack(i, fname[i], ftype[i]);
|
||||
@ -1021,6 +1043,7 @@ XGB_DLL int XGBoosterGetAttr(BoosterHandle handle,
|
||||
auto* bst = static_cast<Booster*>(handle);
|
||||
std::string& ret_str = XGBAPIThreadLocalStore::Get()->ret_str;
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
if (bst->learner()->GetAttr(key, &ret_str)) {
|
||||
*out = ret_str.c_str();
|
||||
*success = 1;
|
||||
@ -1036,6 +1059,7 @@ XGB_DLL int XGBoosterSetAttr(BoosterHandle handle,
|
||||
const char* value) {
|
||||
auto* bst = static_cast<Booster*>(handle);
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
if (value == nullptr) {
|
||||
bst->learner()->DelAttr(key);
|
||||
} else {
|
||||
@ -1051,6 +1075,7 @@ XGB_DLL int XGBoosterGetAttrNames(BoosterHandle handle,
|
||||
std::vector<const char*>& charp_vecs = XGBAPIThreadLocalStore::Get()->ret_vec_charp;
|
||||
auto *bst = static_cast<Booster*>(handle);
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
str_vecs = bst->learner()->GetAttrNames();
|
||||
charp_vecs.resize(str_vecs.size());
|
||||
for (size_t i = 0; i < str_vecs.size(); ++i) {
|
||||
@ -1064,6 +1089,7 @@ XGB_DLL int XGBoosterGetAttrNames(BoosterHandle handle,
|
||||
XGB_DLL int XGBoosterLoadRabitCheckpoint(BoosterHandle handle,
|
||||
int* version) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
auto* bst = static_cast<Booster*>(handle);
|
||||
*version = rabit::LoadCheckPoint(bst->learner());
|
||||
if (*version != 0) {
|
||||
@ -1074,6 +1100,7 @@ XGB_DLL int XGBoosterLoadRabitCheckpoint(BoosterHandle handle,
|
||||
|
||||
XGB_DLL int XGBoosterSaveRabitCheckpoint(BoosterHandle handle) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
auto* bst = static_cast<Booster*>(handle);
|
||||
if (bst->learner()->AllowLazyCheckPoint()) {
|
||||
rabit::LazyCheckPoint(bst->learner());
|
||||
|
||||
@ -15,6 +15,8 @@
|
||||
/*! \brief every function starts with API_BEGIN();
|
||||
and finishes with API_END() or API_END_HANDLE_ERROR */
|
||||
#define API_END() } catch(dmlc::Error &_except_) { return XGBAPIHandleException(_except_); } return 0; // NOLINT(*)
|
||||
#define CHECK_HANDLE() if (handle == nullptr) \
|
||||
LOG(FATAL) << "DMatrix/Booster has not been intialized or has already been disposed.";
|
||||
/*!
|
||||
* \brief every function starts with API_BEGIN();
|
||||
* and finishes with API_END() or API_END_HANDLE_ERROR
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user