Convert handle == nullptr from SegFault to user-friendly error. (#3021)

* Convert SegFault to user-friendly error.

* Apply the change to DMatrix API as well
This commit is contained in:
Yun Ni 2018-06-28 23:30:26 -07:00 committed by Philip Hyunsu Cho
parent 8bec8d5e9a
commit 30d10ab035
3 changed files with 35 additions and 5 deletions

View File

@ -556,11 +556,12 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterPredict
bst_ulong len; bst_ulong len;
float *result; float *result;
int ret = XGBoosterPredict(handle, dmat, joption_mask, (unsigned int) jntree_limit, &len, (const float **) &result); int ret = XGBoosterPredict(handle, dmat, joption_mask, (unsigned int) jntree_limit, &len, (const float **) &result);
if (len) {
jsize jlen = (jsize) len; jsize jlen = (jsize) len;
jfloatArray jarray = jenv->NewFloatArray(jlen); jfloatArray jarray = jenv->NewFloatArray(jlen);
jenv->SetFloatArrayRegion(jarray, 0, jlen, (jfloat *) result); jenv->SetFloatArrayRegion(jarray, 0, jlen, (jfloat *) result);
jenv->SetObjectArrayElement(jout, 0, jarray); jenv->SetObjectArrayElement(jout, 0, jarray);
}
return ret; return ret;
} }

View File

@ -650,6 +650,7 @@ XGB_DLL int XGDMatrixSliceDMatrix(DMatrixHandle handle,
std::unique_ptr<data::SimpleCSRSource> source(new data::SimpleCSRSource()); std::unique_ptr<data::SimpleCSRSource> source(new data::SimpleCSRSource());
API_BEGIN(); API_BEGIN();
CHECK_HANDLE();
data::SimpleCSRSource src; data::SimpleCSRSource src;
src.CopyFrom(static_cast<std::shared_ptr<DMatrix>*>(handle)->get()); src.CopyFrom(static_cast<std::shared_ptr<DMatrix>*>(handle)->get());
data::SimpleCSRSource& ret = *source; data::SimpleCSRSource& ret = *source;
@ -694,6 +695,7 @@ XGB_DLL int XGDMatrixSliceDMatrix(DMatrixHandle handle,
XGB_DLL int XGDMatrixFree(DMatrixHandle handle) { XGB_DLL int XGDMatrixFree(DMatrixHandle handle) {
API_BEGIN(); API_BEGIN();
CHECK_HANDLE();
delete static_cast<std::shared_ptr<DMatrix>*>(handle); delete static_cast<std::shared_ptr<DMatrix>*>(handle);
API_END(); API_END();
} }
@ -702,6 +704,7 @@ XGB_DLL int XGDMatrixSaveBinary(DMatrixHandle handle,
const char* fname, const char* fname,
int silent) { int silent) {
API_BEGIN(); API_BEGIN();
CHECK_HANDLE();
static_cast<std::shared_ptr<DMatrix>*>(handle)->get()->SaveToLocalFile(fname); static_cast<std::shared_ptr<DMatrix>*>(handle)->get()->SaveToLocalFile(fname);
API_END(); API_END();
} }
@ -711,6 +714,7 @@ XGB_DLL int XGDMatrixSetFloatInfo(DMatrixHandle handle,
const bst_float* info, const bst_float* info,
xgboost::bst_ulong len) { xgboost::bst_ulong len) {
API_BEGIN(); API_BEGIN();
CHECK_HANDLE();
static_cast<std::shared_ptr<DMatrix>*>(handle) static_cast<std::shared_ptr<DMatrix>*>(handle)
->get()->Info().SetInfo(field, info, kFloat32, len); ->get()->Info().SetInfo(field, info, kFloat32, len);
API_END(); API_END();
@ -721,6 +725,7 @@ XGB_DLL int XGDMatrixSetUIntInfo(DMatrixHandle handle,
const unsigned* info, const unsigned* info,
xgboost::bst_ulong len) { xgboost::bst_ulong len) {
API_BEGIN(); API_BEGIN();
CHECK_HANDLE();
static_cast<std::shared_ptr<DMatrix>*>(handle) static_cast<std::shared_ptr<DMatrix>*>(handle)
->get()->Info().SetInfo(field, info, kUInt32, len); ->get()->Info().SetInfo(field, info, kUInt32, len);
API_END(); API_END();
@ -730,6 +735,7 @@ XGB_DLL int XGDMatrixSetGroup(DMatrixHandle handle,
const unsigned* group, const unsigned* group,
xgboost::bst_ulong len) { xgboost::bst_ulong len) {
API_BEGIN(); API_BEGIN();
CHECK_HANDLE();
auto *pmat = static_cast<std::shared_ptr<DMatrix>*>(handle); auto *pmat = static_cast<std::shared_ptr<DMatrix>*>(handle);
MetaInfo& info = pmat->get()->Info(); MetaInfo& info = pmat->get()->Info();
info.group_ptr_.resize(len + 1); info.group_ptr_.resize(len + 1);
@ -745,6 +751,7 @@ XGB_DLL int XGDMatrixGetFloatInfo(const DMatrixHandle handle,
xgboost::bst_ulong* out_len, xgboost::bst_ulong* out_len,
const bst_float** out_dptr) { const bst_float** out_dptr) {
API_BEGIN(); API_BEGIN();
CHECK_HANDLE();
const MetaInfo& info = static_cast<std::shared_ptr<DMatrix>*>(handle)->get()->Info(); const MetaInfo& info = static_cast<std::shared_ptr<DMatrix>*>(handle)->get()->Info();
const std::vector<bst_float>* vec = nullptr; const std::vector<bst_float>* vec = nullptr;
if (!std::strcmp(field, "label")) { if (!std::strcmp(field, "label")) {
@ -766,6 +773,7 @@ XGB_DLL int XGDMatrixGetUIntInfo(const DMatrixHandle handle,
xgboost::bst_ulong *out_len, xgboost::bst_ulong *out_len,
const unsigned **out_dptr) { const unsigned **out_dptr) {
API_BEGIN(); API_BEGIN();
CHECK_HANDLE();
const MetaInfo& info = static_cast<std::shared_ptr<DMatrix>*>(handle)->get()->Info(); const MetaInfo& info = static_cast<std::shared_ptr<DMatrix>*>(handle)->get()->Info();
const std::vector<unsigned>* vec = nullptr; const std::vector<unsigned>* vec = nullptr;
if (!std::strcmp(field, "root_index")) { if (!std::strcmp(field, "root_index")) {
@ -781,6 +789,7 @@ XGB_DLL int XGDMatrixGetUIntInfo(const DMatrixHandle handle,
XGB_DLL int XGDMatrixNumRow(const DMatrixHandle handle, XGB_DLL int XGDMatrixNumRow(const DMatrixHandle handle,
xgboost::bst_ulong *out) { xgboost::bst_ulong *out) {
API_BEGIN(); API_BEGIN();
CHECK_HANDLE();
*out = static_cast<xgboost::bst_ulong>( *out = static_cast<xgboost::bst_ulong>(
static_cast<std::shared_ptr<DMatrix>*>(handle)->get()->Info().num_row_); static_cast<std::shared_ptr<DMatrix>*>(handle)->get()->Info().num_row_);
API_END(); API_END();
@ -789,6 +798,7 @@ XGB_DLL int XGDMatrixNumRow(const DMatrixHandle handle,
XGB_DLL int XGDMatrixNumCol(const DMatrixHandle handle, XGB_DLL int XGDMatrixNumCol(const DMatrixHandle handle,
xgboost::bst_ulong *out) { xgboost::bst_ulong *out) {
API_BEGIN(); API_BEGIN();
CHECK_HANDLE();
*out = static_cast<size_t>( *out = static_cast<size_t>(
static_cast<std::shared_ptr<DMatrix>*>(handle)->get()->Info().num_col_); static_cast<std::shared_ptr<DMatrix>*>(handle)->get()->Info().num_col_);
API_END(); API_END();
@ -809,6 +819,7 @@ XGB_DLL int XGBoosterCreate(const DMatrixHandle dmats[],
XGB_DLL int XGBoosterFree(BoosterHandle handle) { XGB_DLL int XGBoosterFree(BoosterHandle handle) {
API_BEGIN(); API_BEGIN();
CHECK_HANDLE();
delete static_cast<Booster*>(handle); delete static_cast<Booster*>(handle);
API_END(); API_END();
} }
@ -817,6 +828,7 @@ XGB_DLL int XGBoosterSetParam(BoosterHandle handle,
const char *name, const char *name,
const char *value) { const char *value) {
API_BEGIN(); API_BEGIN();
CHECK_HANDLE();
static_cast<Booster*>(handle)->SetParam(name, value); static_cast<Booster*>(handle)->SetParam(name, value);
API_END(); API_END();
} }
@ -825,6 +837,7 @@ XGB_DLL int XGBoosterUpdateOneIter(BoosterHandle handle,
int iter, int iter,
DMatrixHandle dtrain) { DMatrixHandle dtrain) {
API_BEGIN(); API_BEGIN();
CHECK_HANDLE();
auto* bst = static_cast<Booster*>(handle); auto* bst = static_cast<Booster*>(handle);
auto *dtr = auto *dtr =
static_cast<std::shared_ptr<DMatrix>*>(dtrain); static_cast<std::shared_ptr<DMatrix>*>(dtrain);
@ -841,6 +854,7 @@ XGB_DLL int XGBoosterBoostOneIter(BoosterHandle handle,
xgboost::bst_ulong len) { xgboost::bst_ulong len) {
HostDeviceVector<GradientPair>& tmp_gpair = XGBAPIThreadLocalStore::Get()->tmp_gpair; HostDeviceVector<GradientPair>& tmp_gpair = XGBAPIThreadLocalStore::Get()->tmp_gpair;
API_BEGIN(); API_BEGIN();
CHECK_HANDLE();
auto* bst = static_cast<Booster*>(handle); auto* bst = static_cast<Booster*>(handle);
auto* dtr = auto* dtr =
static_cast<std::shared_ptr<DMatrix>*>(dtrain); static_cast<std::shared_ptr<DMatrix>*>(dtrain);
@ -863,6 +877,7 @@ XGB_DLL int XGBoosterEvalOneIter(BoosterHandle handle,
const char** out_str) { const char** out_str) {
std::string& eval_str = XGBAPIThreadLocalStore::Get()->ret_str; std::string& eval_str = XGBAPIThreadLocalStore::Get()->ret_str;
API_BEGIN(); API_BEGIN();
CHECK_HANDLE();
auto* bst = static_cast<Booster*>(handle); auto* bst = static_cast<Booster*>(handle);
std::vector<DMatrix*> data_sets; std::vector<DMatrix*> data_sets;
std::vector<std::string> data_names; std::vector<std::string> data_names;
@ -887,6 +902,7 @@ XGB_DLL int XGBoosterPredict(BoosterHandle handle,
HostDeviceVector<bst_float>& preds = HostDeviceVector<bst_float>& preds =
XGBAPIThreadLocalStore::Get()->ret_vec_float; XGBAPIThreadLocalStore::Get()->ret_vec_float;
API_BEGIN(); API_BEGIN();
CHECK_HANDLE();
auto *bst = static_cast<Booster*>(handle); auto *bst = static_cast<Booster*>(handle);
bst->LazyInit(); bst->LazyInit();
bst->learner()->Predict( bst->learner()->Predict(
@ -904,6 +920,7 @@ XGB_DLL int XGBoosterPredict(BoosterHandle handle,
XGB_DLL int XGBoosterLoadModel(BoosterHandle handle, const char* fname) { XGB_DLL int XGBoosterLoadModel(BoosterHandle handle, const char* fname) {
API_BEGIN(); API_BEGIN();
CHECK_HANDLE();
std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname, "r")); std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname, "r"));
static_cast<Booster*>(handle)->LoadModel(fi.get()); static_cast<Booster*>(handle)->LoadModel(fi.get());
API_END(); API_END();
@ -911,6 +928,7 @@ XGB_DLL int XGBoosterLoadModel(BoosterHandle handle, const char* fname) {
XGB_DLL int XGBoosterSaveModel(BoosterHandle handle, const char* fname) { XGB_DLL int XGBoosterSaveModel(BoosterHandle handle, const char* fname) {
API_BEGIN(); API_BEGIN();
CHECK_HANDLE();
std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(fname, "w")); std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(fname, "w"));
auto *bst = static_cast<Booster*>(handle); auto *bst = static_cast<Booster*>(handle);
bst->LazyInit(); bst->LazyInit();
@ -922,6 +940,7 @@ XGB_DLL int XGBoosterLoadModelFromBuffer(BoosterHandle handle,
const void* buf, const void* buf,
xgboost::bst_ulong len) { xgboost::bst_ulong len) {
API_BEGIN(); API_BEGIN();
CHECK_HANDLE();
common::MemoryFixSizeBuffer fs((void*)buf, len); // NOLINT(*) common::MemoryFixSizeBuffer fs((void*)buf, len); // NOLINT(*)
static_cast<Booster*>(handle)->LoadModel(&fs); static_cast<Booster*>(handle)->LoadModel(&fs);
API_END(); API_END();
@ -934,6 +953,7 @@ XGB_DLL int XGBoosterGetModelRaw(BoosterHandle handle,
raw_str.resize(0); raw_str.resize(0);
API_BEGIN(); API_BEGIN();
CHECK_HANDLE();
common::MemoryBufferStream fo(&raw_str); common::MemoryBufferStream fo(&raw_str);
auto *bst = static_cast<Booster*>(handle); auto *bst = static_cast<Booster*>(handle);
bst->LazyInit(); bst->LazyInit();
@ -976,6 +996,7 @@ XGB_DLL int XGBoosterDumpModelEx(BoosterHandle handle,
xgboost::bst_ulong* len, xgboost::bst_ulong* len,
const char*** out_models) { const char*** out_models) {
API_BEGIN(); API_BEGIN();
CHECK_HANDLE();
FeatureMap featmap; FeatureMap featmap;
if (strlen(fmap) != 0) { if (strlen(fmap) != 0) {
std::unique_ptr<dmlc::Stream> fs( std::unique_ptr<dmlc::Stream> fs(
@ -1006,6 +1027,7 @@ XGB_DLL int XGBoosterDumpModelExWithFeatures(BoosterHandle handle,
xgboost::bst_ulong* len, xgboost::bst_ulong* len,
const char*** out_models) { const char*** out_models) {
API_BEGIN(); API_BEGIN();
CHECK_HANDLE();
FeatureMap featmap; FeatureMap featmap;
for (int i = 0; i < fnum; ++i) { for (int i = 0; i < fnum; ++i) {
featmap.PushBack(i, fname[i], ftype[i]); featmap.PushBack(i, fname[i], ftype[i]);
@ -1021,6 +1043,7 @@ XGB_DLL int XGBoosterGetAttr(BoosterHandle handle,
auto* bst = static_cast<Booster*>(handle); auto* bst = static_cast<Booster*>(handle);
std::string& ret_str = XGBAPIThreadLocalStore::Get()->ret_str; std::string& ret_str = XGBAPIThreadLocalStore::Get()->ret_str;
API_BEGIN(); API_BEGIN();
CHECK_HANDLE();
if (bst->learner()->GetAttr(key, &ret_str)) { if (bst->learner()->GetAttr(key, &ret_str)) {
*out = ret_str.c_str(); *out = ret_str.c_str();
*success = 1; *success = 1;
@ -1036,6 +1059,7 @@ XGB_DLL int XGBoosterSetAttr(BoosterHandle handle,
const char* value) { const char* value) {
auto* bst = static_cast<Booster*>(handle); auto* bst = static_cast<Booster*>(handle);
API_BEGIN(); API_BEGIN();
CHECK_HANDLE();
if (value == nullptr) { if (value == nullptr) {
bst->learner()->DelAttr(key); bst->learner()->DelAttr(key);
} else { } else {
@ -1051,6 +1075,7 @@ XGB_DLL int XGBoosterGetAttrNames(BoosterHandle handle,
std::vector<const char*>& charp_vecs = XGBAPIThreadLocalStore::Get()->ret_vec_charp; std::vector<const char*>& charp_vecs = XGBAPIThreadLocalStore::Get()->ret_vec_charp;
auto *bst = static_cast<Booster*>(handle); auto *bst = static_cast<Booster*>(handle);
API_BEGIN(); API_BEGIN();
CHECK_HANDLE();
str_vecs = bst->learner()->GetAttrNames(); str_vecs = bst->learner()->GetAttrNames();
charp_vecs.resize(str_vecs.size()); charp_vecs.resize(str_vecs.size());
for (size_t i = 0; i < str_vecs.size(); ++i) { for (size_t i = 0; i < str_vecs.size(); ++i) {
@ -1064,6 +1089,7 @@ XGB_DLL int XGBoosterGetAttrNames(BoosterHandle handle,
XGB_DLL int XGBoosterLoadRabitCheckpoint(BoosterHandle handle, XGB_DLL int XGBoosterLoadRabitCheckpoint(BoosterHandle handle,
int* version) { int* version) {
API_BEGIN(); API_BEGIN();
CHECK_HANDLE();
auto* bst = static_cast<Booster*>(handle); auto* bst = static_cast<Booster*>(handle);
*version = rabit::LoadCheckPoint(bst->learner()); *version = rabit::LoadCheckPoint(bst->learner());
if (*version != 0) { if (*version != 0) {
@ -1074,6 +1100,7 @@ XGB_DLL int XGBoosterLoadRabitCheckpoint(BoosterHandle handle,
XGB_DLL int XGBoosterSaveRabitCheckpoint(BoosterHandle handle) { XGB_DLL int XGBoosterSaveRabitCheckpoint(BoosterHandle handle) {
API_BEGIN(); API_BEGIN();
CHECK_HANDLE();
auto* bst = static_cast<Booster*>(handle); auto* bst = static_cast<Booster*>(handle);
if (bst->learner()->AllowLazyCheckPoint()) { if (bst->learner()->AllowLazyCheckPoint()) {
rabit::LazyCheckPoint(bst->learner()); rabit::LazyCheckPoint(bst->learner());

View File

@ -15,6 +15,8 @@
/*! \brief every function starts with API_BEGIN(); /*! \brief every function starts with API_BEGIN();
and finishes with API_END() or API_END_HANDLE_ERROR */ and finishes with API_END() or API_END_HANDLE_ERROR */
#define API_END() } catch(dmlc::Error &_except_) { return XGBAPIHandleException(_except_); } return 0; // NOLINT(*) #define API_END() } catch(dmlc::Error &_except_) { return XGBAPIHandleException(_except_); } return 0; // NOLINT(*)
#define CHECK_HANDLE() if (handle == nullptr) \
LOG(FATAL) << "DMatrix/Booster has not been intialized or has already been disposed.";
/*! /*!
* \brief every function starts with API_BEGIN(); * \brief every function starts with API_BEGIN();
* and finishes with API_END() or API_END_HANDLE_ERROR * and finishes with API_END() or API_END_HANDLE_ERROR