Convert handle == nullptr from SegFault to user-friendly error. (#3021)

* Convert SegFault to user-friendly error.

* Apply the change to DMatrix API as well
This commit is contained in:
Yun Ni 2018-06-28 23:30:26 -07:00 committed by Philip Hyunsu Cho
parent 8bec8d5e9a
commit 30d10ab035
3 changed files with 35 additions and 5 deletions

View File

@ -556,11 +556,12 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterPredict
bst_ulong len;
float *result;
int ret = XGBoosterPredict(handle, dmat, joption_mask, (unsigned int) jntree_limit, &len, (const float **) &result);
if (len) {
jsize jlen = (jsize) len;
jfloatArray jarray = jenv->NewFloatArray(jlen);
jenv->SetFloatArrayRegion(jarray, 0, jlen, (jfloat *) result);
jenv->SetObjectArrayElement(jout, 0, jarray);
}
return ret;
}

View File

@ -650,6 +650,7 @@ XGB_DLL int XGDMatrixSliceDMatrix(DMatrixHandle handle,
std::unique_ptr<data::SimpleCSRSource> source(new data::SimpleCSRSource());
API_BEGIN();
CHECK_HANDLE();
data::SimpleCSRSource src;
src.CopyFrom(static_cast<std::shared_ptr<DMatrix>*>(handle)->get());
data::SimpleCSRSource& ret = *source;
@ -694,6 +695,7 @@ XGB_DLL int XGDMatrixSliceDMatrix(DMatrixHandle handle,
XGB_DLL int XGDMatrixFree(DMatrixHandle handle) {
API_BEGIN();
CHECK_HANDLE();
delete static_cast<std::shared_ptr<DMatrix>*>(handle);
API_END();
}
@ -702,6 +704,7 @@ XGB_DLL int XGDMatrixSaveBinary(DMatrixHandle handle,
const char* fname,
int silent) {
API_BEGIN();
CHECK_HANDLE();
static_cast<std::shared_ptr<DMatrix>*>(handle)->get()->SaveToLocalFile(fname);
API_END();
}
@ -711,6 +714,7 @@ XGB_DLL int XGDMatrixSetFloatInfo(DMatrixHandle handle,
const bst_float* info,
xgboost::bst_ulong len) {
API_BEGIN();
CHECK_HANDLE();
static_cast<std::shared_ptr<DMatrix>*>(handle)
->get()->Info().SetInfo(field, info, kFloat32, len);
API_END();
@ -721,6 +725,7 @@ XGB_DLL int XGDMatrixSetUIntInfo(DMatrixHandle handle,
const unsigned* info,
xgboost::bst_ulong len) {
API_BEGIN();
CHECK_HANDLE();
static_cast<std::shared_ptr<DMatrix>*>(handle)
->get()->Info().SetInfo(field, info, kUInt32, len);
API_END();
@ -730,6 +735,7 @@ XGB_DLL int XGDMatrixSetGroup(DMatrixHandle handle,
const unsigned* group,
xgboost::bst_ulong len) {
API_BEGIN();
CHECK_HANDLE();
auto *pmat = static_cast<std::shared_ptr<DMatrix>*>(handle);
MetaInfo& info = pmat->get()->Info();
info.group_ptr_.resize(len + 1);
@ -745,6 +751,7 @@ XGB_DLL int XGDMatrixGetFloatInfo(const DMatrixHandle handle,
xgboost::bst_ulong* out_len,
const bst_float** out_dptr) {
API_BEGIN();
CHECK_HANDLE();
const MetaInfo& info = static_cast<std::shared_ptr<DMatrix>*>(handle)->get()->Info();
const std::vector<bst_float>* vec = nullptr;
if (!std::strcmp(field, "label")) {
@ -766,6 +773,7 @@ XGB_DLL int XGDMatrixGetUIntInfo(const DMatrixHandle handle,
xgboost::bst_ulong *out_len,
const unsigned **out_dptr) {
API_BEGIN();
CHECK_HANDLE();
const MetaInfo& info = static_cast<std::shared_ptr<DMatrix>*>(handle)->get()->Info();
const std::vector<unsigned>* vec = nullptr;
if (!std::strcmp(field, "root_index")) {
@ -781,6 +789,7 @@ XGB_DLL int XGDMatrixGetUIntInfo(const DMatrixHandle handle,
XGB_DLL int XGDMatrixNumRow(const DMatrixHandle handle,
xgboost::bst_ulong *out) {
API_BEGIN();
CHECK_HANDLE();
*out = static_cast<xgboost::bst_ulong>(
static_cast<std::shared_ptr<DMatrix>*>(handle)->get()->Info().num_row_);
API_END();
@ -789,6 +798,7 @@ XGB_DLL int XGDMatrixNumRow(const DMatrixHandle handle,
XGB_DLL int XGDMatrixNumCol(const DMatrixHandle handle,
xgboost::bst_ulong *out) {
API_BEGIN();
CHECK_HANDLE();
*out = static_cast<size_t>(
static_cast<std::shared_ptr<DMatrix>*>(handle)->get()->Info().num_col_);
API_END();
@ -809,6 +819,7 @@ XGB_DLL int XGBoosterCreate(const DMatrixHandle dmats[],
XGB_DLL int XGBoosterFree(BoosterHandle handle) {
API_BEGIN();
CHECK_HANDLE();
delete static_cast<Booster*>(handle);
API_END();
}
@ -817,6 +828,7 @@ XGB_DLL int XGBoosterSetParam(BoosterHandle handle,
const char *name,
const char *value) {
API_BEGIN();
CHECK_HANDLE();
static_cast<Booster*>(handle)->SetParam(name, value);
API_END();
}
@ -825,6 +837,7 @@ XGB_DLL int XGBoosterUpdateOneIter(BoosterHandle handle,
int iter,
DMatrixHandle dtrain) {
API_BEGIN();
CHECK_HANDLE();
auto* bst = static_cast<Booster*>(handle);
auto *dtr =
static_cast<std::shared_ptr<DMatrix>*>(dtrain);
@ -841,6 +854,7 @@ XGB_DLL int XGBoosterBoostOneIter(BoosterHandle handle,
xgboost::bst_ulong len) {
HostDeviceVector<GradientPair>& tmp_gpair = XGBAPIThreadLocalStore::Get()->tmp_gpair;
API_BEGIN();
CHECK_HANDLE();
auto* bst = static_cast<Booster*>(handle);
auto* dtr =
static_cast<std::shared_ptr<DMatrix>*>(dtrain);
@ -863,6 +877,7 @@ XGB_DLL int XGBoosterEvalOneIter(BoosterHandle handle,
const char** out_str) {
std::string& eval_str = XGBAPIThreadLocalStore::Get()->ret_str;
API_BEGIN();
CHECK_HANDLE();
auto* bst = static_cast<Booster*>(handle);
std::vector<DMatrix*> data_sets;
std::vector<std::string> data_names;
@ -887,6 +902,7 @@ XGB_DLL int XGBoosterPredict(BoosterHandle handle,
HostDeviceVector<bst_float>& preds =
XGBAPIThreadLocalStore::Get()->ret_vec_float;
API_BEGIN();
CHECK_HANDLE();
auto *bst = static_cast<Booster*>(handle);
bst->LazyInit();
bst->learner()->Predict(
@ -904,6 +920,7 @@ XGB_DLL int XGBoosterPredict(BoosterHandle handle,
XGB_DLL int XGBoosterLoadModel(BoosterHandle handle, const char* fname) {
API_BEGIN();
CHECK_HANDLE();
std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname, "r"));
static_cast<Booster*>(handle)->LoadModel(fi.get());
API_END();
@ -911,6 +928,7 @@ XGB_DLL int XGBoosterLoadModel(BoosterHandle handle, const char* fname) {
XGB_DLL int XGBoosterSaveModel(BoosterHandle handle, const char* fname) {
API_BEGIN();
CHECK_HANDLE();
std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(fname, "w"));
auto *bst = static_cast<Booster*>(handle);
bst->LazyInit();
@ -922,6 +940,7 @@ XGB_DLL int XGBoosterLoadModelFromBuffer(BoosterHandle handle,
const void* buf,
xgboost::bst_ulong len) {
API_BEGIN();
CHECK_HANDLE();
common::MemoryFixSizeBuffer fs((void*)buf, len); // NOLINT(*)
static_cast<Booster*>(handle)->LoadModel(&fs);
API_END();
@ -934,6 +953,7 @@ XGB_DLL int XGBoosterGetModelRaw(BoosterHandle handle,
raw_str.resize(0);
API_BEGIN();
CHECK_HANDLE();
common::MemoryBufferStream fo(&raw_str);
auto *bst = static_cast<Booster*>(handle);
bst->LazyInit();
@ -976,6 +996,7 @@ XGB_DLL int XGBoosterDumpModelEx(BoosterHandle handle,
xgboost::bst_ulong* len,
const char*** out_models) {
API_BEGIN();
CHECK_HANDLE();
FeatureMap featmap;
if (strlen(fmap) != 0) {
std::unique_ptr<dmlc::Stream> fs(
@ -1006,6 +1027,7 @@ XGB_DLL int XGBoosterDumpModelExWithFeatures(BoosterHandle handle,
xgboost::bst_ulong* len,
const char*** out_models) {
API_BEGIN();
CHECK_HANDLE();
FeatureMap featmap;
for (int i = 0; i < fnum; ++i) {
featmap.PushBack(i, fname[i], ftype[i]);
@ -1021,6 +1043,7 @@ XGB_DLL int XGBoosterGetAttr(BoosterHandle handle,
auto* bst = static_cast<Booster*>(handle);
std::string& ret_str = XGBAPIThreadLocalStore::Get()->ret_str;
API_BEGIN();
CHECK_HANDLE();
if (bst->learner()->GetAttr(key, &ret_str)) {
*out = ret_str.c_str();
*success = 1;
@ -1036,6 +1059,7 @@ XGB_DLL int XGBoosterSetAttr(BoosterHandle handle,
const char* value) {
auto* bst = static_cast<Booster*>(handle);
API_BEGIN();
CHECK_HANDLE();
if (value == nullptr) {
bst->learner()->DelAttr(key);
} else {
@ -1051,6 +1075,7 @@ XGB_DLL int XGBoosterGetAttrNames(BoosterHandle handle,
std::vector<const char*>& charp_vecs = XGBAPIThreadLocalStore::Get()->ret_vec_charp;
auto *bst = static_cast<Booster*>(handle);
API_BEGIN();
CHECK_HANDLE();
str_vecs = bst->learner()->GetAttrNames();
charp_vecs.resize(str_vecs.size());
for (size_t i = 0; i < str_vecs.size(); ++i) {
@ -1064,6 +1089,7 @@ XGB_DLL int XGBoosterGetAttrNames(BoosterHandle handle,
XGB_DLL int XGBoosterLoadRabitCheckpoint(BoosterHandle handle,
int* version) {
API_BEGIN();
CHECK_HANDLE();
auto* bst = static_cast<Booster*>(handle);
*version = rabit::LoadCheckPoint(bst->learner());
if (*version != 0) {
@ -1074,6 +1100,7 @@ XGB_DLL int XGBoosterLoadRabitCheckpoint(BoosterHandle handle,
XGB_DLL int XGBoosterSaveRabitCheckpoint(BoosterHandle handle) {
API_BEGIN();
CHECK_HANDLE();
auto* bst = static_cast<Booster*>(handle);
if (bst->learner()->AllowLazyCheckPoint()) {
rabit::LazyCheckPoint(bst->learner());

View File

@ -15,6 +15,8 @@
/*! \brief every function starts with API_BEGIN();
and finishes with API_END() or API_END_HANDLE_ERROR */
#define API_END() } catch(dmlc::Error &_except_) { return XGBAPIHandleException(_except_); } return 0; // NOLINT(*)
#define CHECK_HANDLE() if (handle == nullptr) \
LOG(FATAL) << "DMatrix/Booster has not been intialized or has already been disposed.";
/*!
* \brief every function starts with API_BEGIN();
* and finishes with API_END() or API_END_HANDLE_ERROR