Convert handle == nullptr from SegFault to user-friendly error. (#3021)
* Convert SegFault to user-friendly error. * Apply the change to DMatrix API as well
This commit is contained in:
parent
8bec8d5e9a
commit
30d10ab035
@ -556,11 +556,12 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterPredict
|
|||||||
bst_ulong len;
|
bst_ulong len;
|
||||||
float *result;
|
float *result;
|
||||||
int ret = XGBoosterPredict(handle, dmat, joption_mask, (unsigned int) jntree_limit, &len, (const float **) &result);
|
int ret = XGBoosterPredict(handle, dmat, joption_mask, (unsigned int) jntree_limit, &len, (const float **) &result);
|
||||||
|
if (len) {
|
||||||
jsize jlen = (jsize) len;
|
jsize jlen = (jsize) len;
|
||||||
jfloatArray jarray = jenv->NewFloatArray(jlen);
|
jfloatArray jarray = jenv->NewFloatArray(jlen);
|
||||||
jenv->SetFloatArrayRegion(jarray, 0, jlen, (jfloat *) result);
|
jenv->SetFloatArrayRegion(jarray, 0, jlen, (jfloat *) result);
|
||||||
jenv->SetObjectArrayElement(jout, 0, jarray);
|
jenv->SetObjectArrayElement(jout, 0, jarray);
|
||||||
|
}
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -650,6 +650,7 @@ XGB_DLL int XGDMatrixSliceDMatrix(DMatrixHandle handle,
|
|||||||
std::unique_ptr<data::SimpleCSRSource> source(new data::SimpleCSRSource());
|
std::unique_ptr<data::SimpleCSRSource> source(new data::SimpleCSRSource());
|
||||||
|
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
|
CHECK_HANDLE();
|
||||||
data::SimpleCSRSource src;
|
data::SimpleCSRSource src;
|
||||||
src.CopyFrom(static_cast<std::shared_ptr<DMatrix>*>(handle)->get());
|
src.CopyFrom(static_cast<std::shared_ptr<DMatrix>*>(handle)->get());
|
||||||
data::SimpleCSRSource& ret = *source;
|
data::SimpleCSRSource& ret = *source;
|
||||||
@ -694,6 +695,7 @@ XGB_DLL int XGDMatrixSliceDMatrix(DMatrixHandle handle,
|
|||||||
|
|
||||||
XGB_DLL int XGDMatrixFree(DMatrixHandle handle) {
|
XGB_DLL int XGDMatrixFree(DMatrixHandle handle) {
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
|
CHECK_HANDLE();
|
||||||
delete static_cast<std::shared_ptr<DMatrix>*>(handle);
|
delete static_cast<std::shared_ptr<DMatrix>*>(handle);
|
||||||
API_END();
|
API_END();
|
||||||
}
|
}
|
||||||
@ -702,6 +704,7 @@ XGB_DLL int XGDMatrixSaveBinary(DMatrixHandle handle,
|
|||||||
const char* fname,
|
const char* fname,
|
||||||
int silent) {
|
int silent) {
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
|
CHECK_HANDLE();
|
||||||
static_cast<std::shared_ptr<DMatrix>*>(handle)->get()->SaveToLocalFile(fname);
|
static_cast<std::shared_ptr<DMatrix>*>(handle)->get()->SaveToLocalFile(fname);
|
||||||
API_END();
|
API_END();
|
||||||
}
|
}
|
||||||
@ -711,6 +714,7 @@ XGB_DLL int XGDMatrixSetFloatInfo(DMatrixHandle handle,
|
|||||||
const bst_float* info,
|
const bst_float* info,
|
||||||
xgboost::bst_ulong len) {
|
xgboost::bst_ulong len) {
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
|
CHECK_HANDLE();
|
||||||
static_cast<std::shared_ptr<DMatrix>*>(handle)
|
static_cast<std::shared_ptr<DMatrix>*>(handle)
|
||||||
->get()->Info().SetInfo(field, info, kFloat32, len);
|
->get()->Info().SetInfo(field, info, kFloat32, len);
|
||||||
API_END();
|
API_END();
|
||||||
@ -721,6 +725,7 @@ XGB_DLL int XGDMatrixSetUIntInfo(DMatrixHandle handle,
|
|||||||
const unsigned* info,
|
const unsigned* info,
|
||||||
xgboost::bst_ulong len) {
|
xgboost::bst_ulong len) {
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
|
CHECK_HANDLE();
|
||||||
static_cast<std::shared_ptr<DMatrix>*>(handle)
|
static_cast<std::shared_ptr<DMatrix>*>(handle)
|
||||||
->get()->Info().SetInfo(field, info, kUInt32, len);
|
->get()->Info().SetInfo(field, info, kUInt32, len);
|
||||||
API_END();
|
API_END();
|
||||||
@ -730,6 +735,7 @@ XGB_DLL int XGDMatrixSetGroup(DMatrixHandle handle,
|
|||||||
const unsigned* group,
|
const unsigned* group,
|
||||||
xgboost::bst_ulong len) {
|
xgboost::bst_ulong len) {
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
|
CHECK_HANDLE();
|
||||||
auto *pmat = static_cast<std::shared_ptr<DMatrix>*>(handle);
|
auto *pmat = static_cast<std::shared_ptr<DMatrix>*>(handle);
|
||||||
MetaInfo& info = pmat->get()->Info();
|
MetaInfo& info = pmat->get()->Info();
|
||||||
info.group_ptr_.resize(len + 1);
|
info.group_ptr_.resize(len + 1);
|
||||||
@ -745,6 +751,7 @@ XGB_DLL int XGDMatrixGetFloatInfo(const DMatrixHandle handle,
|
|||||||
xgboost::bst_ulong* out_len,
|
xgboost::bst_ulong* out_len,
|
||||||
const bst_float** out_dptr) {
|
const bst_float** out_dptr) {
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
|
CHECK_HANDLE();
|
||||||
const MetaInfo& info = static_cast<std::shared_ptr<DMatrix>*>(handle)->get()->Info();
|
const MetaInfo& info = static_cast<std::shared_ptr<DMatrix>*>(handle)->get()->Info();
|
||||||
const std::vector<bst_float>* vec = nullptr;
|
const std::vector<bst_float>* vec = nullptr;
|
||||||
if (!std::strcmp(field, "label")) {
|
if (!std::strcmp(field, "label")) {
|
||||||
@ -766,6 +773,7 @@ XGB_DLL int XGDMatrixGetUIntInfo(const DMatrixHandle handle,
|
|||||||
xgboost::bst_ulong *out_len,
|
xgboost::bst_ulong *out_len,
|
||||||
const unsigned **out_dptr) {
|
const unsigned **out_dptr) {
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
|
CHECK_HANDLE();
|
||||||
const MetaInfo& info = static_cast<std::shared_ptr<DMatrix>*>(handle)->get()->Info();
|
const MetaInfo& info = static_cast<std::shared_ptr<DMatrix>*>(handle)->get()->Info();
|
||||||
const std::vector<unsigned>* vec = nullptr;
|
const std::vector<unsigned>* vec = nullptr;
|
||||||
if (!std::strcmp(field, "root_index")) {
|
if (!std::strcmp(field, "root_index")) {
|
||||||
@ -781,6 +789,7 @@ XGB_DLL int XGDMatrixGetUIntInfo(const DMatrixHandle handle,
|
|||||||
XGB_DLL int XGDMatrixNumRow(const DMatrixHandle handle,
|
XGB_DLL int XGDMatrixNumRow(const DMatrixHandle handle,
|
||||||
xgboost::bst_ulong *out) {
|
xgboost::bst_ulong *out) {
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
|
CHECK_HANDLE();
|
||||||
*out = static_cast<xgboost::bst_ulong>(
|
*out = static_cast<xgboost::bst_ulong>(
|
||||||
static_cast<std::shared_ptr<DMatrix>*>(handle)->get()->Info().num_row_);
|
static_cast<std::shared_ptr<DMatrix>*>(handle)->get()->Info().num_row_);
|
||||||
API_END();
|
API_END();
|
||||||
@ -789,6 +798,7 @@ XGB_DLL int XGDMatrixNumRow(const DMatrixHandle handle,
|
|||||||
XGB_DLL int XGDMatrixNumCol(const DMatrixHandle handle,
|
XGB_DLL int XGDMatrixNumCol(const DMatrixHandle handle,
|
||||||
xgboost::bst_ulong *out) {
|
xgboost::bst_ulong *out) {
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
|
CHECK_HANDLE();
|
||||||
*out = static_cast<size_t>(
|
*out = static_cast<size_t>(
|
||||||
static_cast<std::shared_ptr<DMatrix>*>(handle)->get()->Info().num_col_);
|
static_cast<std::shared_ptr<DMatrix>*>(handle)->get()->Info().num_col_);
|
||||||
API_END();
|
API_END();
|
||||||
@ -809,6 +819,7 @@ XGB_DLL int XGBoosterCreate(const DMatrixHandle dmats[],
|
|||||||
|
|
||||||
XGB_DLL int XGBoosterFree(BoosterHandle handle) {
|
XGB_DLL int XGBoosterFree(BoosterHandle handle) {
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
|
CHECK_HANDLE();
|
||||||
delete static_cast<Booster*>(handle);
|
delete static_cast<Booster*>(handle);
|
||||||
API_END();
|
API_END();
|
||||||
}
|
}
|
||||||
@ -817,6 +828,7 @@ XGB_DLL int XGBoosterSetParam(BoosterHandle handle,
|
|||||||
const char *name,
|
const char *name,
|
||||||
const char *value) {
|
const char *value) {
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
|
CHECK_HANDLE();
|
||||||
static_cast<Booster*>(handle)->SetParam(name, value);
|
static_cast<Booster*>(handle)->SetParam(name, value);
|
||||||
API_END();
|
API_END();
|
||||||
}
|
}
|
||||||
@ -825,6 +837,7 @@ XGB_DLL int XGBoosterUpdateOneIter(BoosterHandle handle,
|
|||||||
int iter,
|
int iter,
|
||||||
DMatrixHandle dtrain) {
|
DMatrixHandle dtrain) {
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
|
CHECK_HANDLE();
|
||||||
auto* bst = static_cast<Booster*>(handle);
|
auto* bst = static_cast<Booster*>(handle);
|
||||||
auto *dtr =
|
auto *dtr =
|
||||||
static_cast<std::shared_ptr<DMatrix>*>(dtrain);
|
static_cast<std::shared_ptr<DMatrix>*>(dtrain);
|
||||||
@ -841,6 +854,7 @@ XGB_DLL int XGBoosterBoostOneIter(BoosterHandle handle,
|
|||||||
xgboost::bst_ulong len) {
|
xgboost::bst_ulong len) {
|
||||||
HostDeviceVector<GradientPair>& tmp_gpair = XGBAPIThreadLocalStore::Get()->tmp_gpair;
|
HostDeviceVector<GradientPair>& tmp_gpair = XGBAPIThreadLocalStore::Get()->tmp_gpair;
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
|
CHECK_HANDLE();
|
||||||
auto* bst = static_cast<Booster*>(handle);
|
auto* bst = static_cast<Booster*>(handle);
|
||||||
auto* dtr =
|
auto* dtr =
|
||||||
static_cast<std::shared_ptr<DMatrix>*>(dtrain);
|
static_cast<std::shared_ptr<DMatrix>*>(dtrain);
|
||||||
@ -863,6 +877,7 @@ XGB_DLL int XGBoosterEvalOneIter(BoosterHandle handle,
|
|||||||
const char** out_str) {
|
const char** out_str) {
|
||||||
std::string& eval_str = XGBAPIThreadLocalStore::Get()->ret_str;
|
std::string& eval_str = XGBAPIThreadLocalStore::Get()->ret_str;
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
|
CHECK_HANDLE();
|
||||||
auto* bst = static_cast<Booster*>(handle);
|
auto* bst = static_cast<Booster*>(handle);
|
||||||
std::vector<DMatrix*> data_sets;
|
std::vector<DMatrix*> data_sets;
|
||||||
std::vector<std::string> data_names;
|
std::vector<std::string> data_names;
|
||||||
@ -887,6 +902,7 @@ XGB_DLL int XGBoosterPredict(BoosterHandle handle,
|
|||||||
HostDeviceVector<bst_float>& preds =
|
HostDeviceVector<bst_float>& preds =
|
||||||
XGBAPIThreadLocalStore::Get()->ret_vec_float;
|
XGBAPIThreadLocalStore::Get()->ret_vec_float;
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
|
CHECK_HANDLE();
|
||||||
auto *bst = static_cast<Booster*>(handle);
|
auto *bst = static_cast<Booster*>(handle);
|
||||||
bst->LazyInit();
|
bst->LazyInit();
|
||||||
bst->learner()->Predict(
|
bst->learner()->Predict(
|
||||||
@ -904,6 +920,7 @@ XGB_DLL int XGBoosterPredict(BoosterHandle handle,
|
|||||||
|
|
||||||
XGB_DLL int XGBoosterLoadModel(BoosterHandle handle, const char* fname) {
|
XGB_DLL int XGBoosterLoadModel(BoosterHandle handle, const char* fname) {
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
|
CHECK_HANDLE();
|
||||||
std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname, "r"));
|
std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname, "r"));
|
||||||
static_cast<Booster*>(handle)->LoadModel(fi.get());
|
static_cast<Booster*>(handle)->LoadModel(fi.get());
|
||||||
API_END();
|
API_END();
|
||||||
@ -911,6 +928,7 @@ XGB_DLL int XGBoosterLoadModel(BoosterHandle handle, const char* fname) {
|
|||||||
|
|
||||||
XGB_DLL int XGBoosterSaveModel(BoosterHandle handle, const char* fname) {
|
XGB_DLL int XGBoosterSaveModel(BoosterHandle handle, const char* fname) {
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
|
CHECK_HANDLE();
|
||||||
std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(fname, "w"));
|
std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(fname, "w"));
|
||||||
auto *bst = static_cast<Booster*>(handle);
|
auto *bst = static_cast<Booster*>(handle);
|
||||||
bst->LazyInit();
|
bst->LazyInit();
|
||||||
@ -922,6 +940,7 @@ XGB_DLL int XGBoosterLoadModelFromBuffer(BoosterHandle handle,
|
|||||||
const void* buf,
|
const void* buf,
|
||||||
xgboost::bst_ulong len) {
|
xgboost::bst_ulong len) {
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
|
CHECK_HANDLE();
|
||||||
common::MemoryFixSizeBuffer fs((void*)buf, len); // NOLINT(*)
|
common::MemoryFixSizeBuffer fs((void*)buf, len); // NOLINT(*)
|
||||||
static_cast<Booster*>(handle)->LoadModel(&fs);
|
static_cast<Booster*>(handle)->LoadModel(&fs);
|
||||||
API_END();
|
API_END();
|
||||||
@ -934,6 +953,7 @@ XGB_DLL int XGBoosterGetModelRaw(BoosterHandle handle,
|
|||||||
raw_str.resize(0);
|
raw_str.resize(0);
|
||||||
|
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
|
CHECK_HANDLE();
|
||||||
common::MemoryBufferStream fo(&raw_str);
|
common::MemoryBufferStream fo(&raw_str);
|
||||||
auto *bst = static_cast<Booster*>(handle);
|
auto *bst = static_cast<Booster*>(handle);
|
||||||
bst->LazyInit();
|
bst->LazyInit();
|
||||||
@ -976,6 +996,7 @@ XGB_DLL int XGBoosterDumpModelEx(BoosterHandle handle,
|
|||||||
xgboost::bst_ulong* len,
|
xgboost::bst_ulong* len,
|
||||||
const char*** out_models) {
|
const char*** out_models) {
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
|
CHECK_HANDLE();
|
||||||
FeatureMap featmap;
|
FeatureMap featmap;
|
||||||
if (strlen(fmap) != 0) {
|
if (strlen(fmap) != 0) {
|
||||||
std::unique_ptr<dmlc::Stream> fs(
|
std::unique_ptr<dmlc::Stream> fs(
|
||||||
@ -1006,6 +1027,7 @@ XGB_DLL int XGBoosterDumpModelExWithFeatures(BoosterHandle handle,
|
|||||||
xgboost::bst_ulong* len,
|
xgboost::bst_ulong* len,
|
||||||
const char*** out_models) {
|
const char*** out_models) {
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
|
CHECK_HANDLE();
|
||||||
FeatureMap featmap;
|
FeatureMap featmap;
|
||||||
for (int i = 0; i < fnum; ++i) {
|
for (int i = 0; i < fnum; ++i) {
|
||||||
featmap.PushBack(i, fname[i], ftype[i]);
|
featmap.PushBack(i, fname[i], ftype[i]);
|
||||||
@ -1021,6 +1043,7 @@ XGB_DLL int XGBoosterGetAttr(BoosterHandle handle,
|
|||||||
auto* bst = static_cast<Booster*>(handle);
|
auto* bst = static_cast<Booster*>(handle);
|
||||||
std::string& ret_str = XGBAPIThreadLocalStore::Get()->ret_str;
|
std::string& ret_str = XGBAPIThreadLocalStore::Get()->ret_str;
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
|
CHECK_HANDLE();
|
||||||
if (bst->learner()->GetAttr(key, &ret_str)) {
|
if (bst->learner()->GetAttr(key, &ret_str)) {
|
||||||
*out = ret_str.c_str();
|
*out = ret_str.c_str();
|
||||||
*success = 1;
|
*success = 1;
|
||||||
@ -1036,6 +1059,7 @@ XGB_DLL int XGBoosterSetAttr(BoosterHandle handle,
|
|||||||
const char* value) {
|
const char* value) {
|
||||||
auto* bst = static_cast<Booster*>(handle);
|
auto* bst = static_cast<Booster*>(handle);
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
|
CHECK_HANDLE();
|
||||||
if (value == nullptr) {
|
if (value == nullptr) {
|
||||||
bst->learner()->DelAttr(key);
|
bst->learner()->DelAttr(key);
|
||||||
} else {
|
} else {
|
||||||
@ -1051,6 +1075,7 @@ XGB_DLL int XGBoosterGetAttrNames(BoosterHandle handle,
|
|||||||
std::vector<const char*>& charp_vecs = XGBAPIThreadLocalStore::Get()->ret_vec_charp;
|
std::vector<const char*>& charp_vecs = XGBAPIThreadLocalStore::Get()->ret_vec_charp;
|
||||||
auto *bst = static_cast<Booster*>(handle);
|
auto *bst = static_cast<Booster*>(handle);
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
|
CHECK_HANDLE();
|
||||||
str_vecs = bst->learner()->GetAttrNames();
|
str_vecs = bst->learner()->GetAttrNames();
|
||||||
charp_vecs.resize(str_vecs.size());
|
charp_vecs.resize(str_vecs.size());
|
||||||
for (size_t i = 0; i < str_vecs.size(); ++i) {
|
for (size_t i = 0; i < str_vecs.size(); ++i) {
|
||||||
@ -1064,6 +1089,7 @@ XGB_DLL int XGBoosterGetAttrNames(BoosterHandle handle,
|
|||||||
XGB_DLL int XGBoosterLoadRabitCheckpoint(BoosterHandle handle,
|
XGB_DLL int XGBoosterLoadRabitCheckpoint(BoosterHandle handle,
|
||||||
int* version) {
|
int* version) {
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
|
CHECK_HANDLE();
|
||||||
auto* bst = static_cast<Booster*>(handle);
|
auto* bst = static_cast<Booster*>(handle);
|
||||||
*version = rabit::LoadCheckPoint(bst->learner());
|
*version = rabit::LoadCheckPoint(bst->learner());
|
||||||
if (*version != 0) {
|
if (*version != 0) {
|
||||||
@ -1074,6 +1100,7 @@ XGB_DLL int XGBoosterLoadRabitCheckpoint(BoosterHandle handle,
|
|||||||
|
|
||||||
XGB_DLL int XGBoosterSaveRabitCheckpoint(BoosterHandle handle) {
|
XGB_DLL int XGBoosterSaveRabitCheckpoint(BoosterHandle handle) {
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
|
CHECK_HANDLE();
|
||||||
auto* bst = static_cast<Booster*>(handle);
|
auto* bst = static_cast<Booster*>(handle);
|
||||||
if (bst->learner()->AllowLazyCheckPoint()) {
|
if (bst->learner()->AllowLazyCheckPoint()) {
|
||||||
rabit::LazyCheckPoint(bst->learner());
|
rabit::LazyCheckPoint(bst->learner());
|
||||||
|
|||||||
@ -15,6 +15,8 @@
|
|||||||
/*! \brief every function starts with API_BEGIN();
|
/*! \brief every function starts with API_BEGIN();
|
||||||
and finishes with API_END() or API_END_HANDLE_ERROR */
|
and finishes with API_END() or API_END_HANDLE_ERROR */
|
||||||
#define API_END() } catch(dmlc::Error &_except_) { return XGBAPIHandleException(_except_); } return 0; // NOLINT(*)
|
#define API_END() } catch(dmlc::Error &_except_) { return XGBAPIHandleException(_except_); } return 0; // NOLINT(*)
|
||||||
|
#define CHECK_HANDLE() if (handle == nullptr) \
|
||||||
|
LOG(FATAL) << "DMatrix/Booster has not been intialized or has already been disposed.";
|
||||||
/*!
|
/*!
|
||||||
* \brief every function starts with API_BEGIN();
|
* \brief every function starts with API_BEGIN();
|
||||||
* and finishes with API_END() or API_END_HANDLE_ERROR
|
* and finishes with API_END() or API_END_HANDLE_ERROR
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user