API refactor to make fault handling easy

This commit is contained in:
tqchen
2015-07-04 18:12:44 -07:00
parent 4d436a3cb0
commit cc767add88
6 changed files with 725 additions and 394 deletions

View File

@@ -59,6 +59,10 @@ inline void _WrapperEnd(void) {
PutRNGstate();
}
// do nothing, check error
inline void CheckErr(int ret) {
}
extern "C" {
SEXP XGCheckNullPtr_R(SEXP handle) {
return ScalarLogical(R_ExternalPtrAddr(handle) == NULL);
@@ -70,7 +74,8 @@ extern "C" {
}
SEXP XGDMatrixCreateFromFile_R(SEXP fname, SEXP silent) {
_WrapperBegin();
void *handle = XGDMatrixCreateFromFile(CHAR(asChar(fname)), asInteger(silent));
DMatrixHandle handle;
CheckErr(XGDMatrixCreateFromFile(CHAR(asChar(fname)), asInteger(silent), &handle));
_WrapperEnd();
SEXP ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue));
R_RegisterCFinalizerEx(ret, _DMatrixFinalizer, TRUE);
@@ -91,7 +96,8 @@ extern "C" {
data[i * ncol +j] = din[i + nrow * j];
}
}
void *handle = XGDMatrixCreateFromMat(BeginPtr(data), nrow, ncol, asReal(missing));
DMatrixHandle handle;
CheckErr(XGDMatrixCreateFromMat(BeginPtr(data), nrow, ncol, asReal(missing), &handle));
_WrapperEnd();
SEXP ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue));
R_RegisterCFinalizerEx(ret, _DMatrixFinalizer, TRUE);
@@ -119,8 +125,10 @@ extern "C" {
indices_[i] = static_cast<unsigned>(p_indices[i]);
data_[i] = static_cast<float>(p_data[i]);
}
void *handle = XGDMatrixCreateFromCSC(BeginPtr(col_ptr_), BeginPtr(indices_),
BeginPtr(data_), nindptr, ndata);
DMatrixHandle handle;
CheckErr(XGDMatrixCreateFromCSC(BeginPtr(col_ptr_), BeginPtr(indices_),
BeginPtr(data_), nindptr, ndata,
&handle));
_WrapperEnd();
SEXP ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue));
R_RegisterCFinalizerEx(ret, _DMatrixFinalizer, TRUE);
@@ -134,7 +142,10 @@ extern "C" {
for (int i = 0; i < len; ++i) {
idxvec[i] = INTEGER(idxset)[i] - 1;
}
void *res = XGDMatrixSliceDMatrix(R_ExternalPtrAddr(handle), BeginPtr(idxvec), len);
DMatrixHandle res;
CheckErr(XGDMatrixSliceDMatrix(R_ExternalPtrAddr(handle),
BeginPtr(idxvec), len,
&res));
_WrapperEnd();
SEXP ret = PROTECT(R_MakeExternalPtr(res, R_NilValue, R_NilValue));
R_RegisterCFinalizerEx(ret, _DMatrixFinalizer, TRUE);
@@ -143,8 +154,8 @@ extern "C" {
}
void XGDMatrixSaveBinary_R(SEXP handle, SEXP fname, SEXP silent) {
_WrapperBegin();
XGDMatrixSaveBinary(R_ExternalPtrAddr(handle),
CHAR(asChar(fname)), asInteger(silent));
CheckErr(XGDMatrixSaveBinary(R_ExternalPtrAddr(handle),
CHAR(asChar(fname)), asInteger(silent)));
_WrapperEnd();
}
void XGDMatrixSetInfo_R(SEXP handle, SEXP field, SEXP array) {
@@ -157,24 +168,27 @@ extern "C" {
for (int i = 0; i < len; ++i) {
vec[i] = static_cast<unsigned>(INTEGER(array)[i]);
}
XGDMatrixSetGroup(R_ExternalPtrAddr(handle), BeginPtr(vec), len);
CheckErr(XGDMatrixSetGroup(R_ExternalPtrAddr(handle), BeginPtr(vec), len));
} else {
std::vector<float> vec(len);
#pragma omp parallel for schedule(static)
for (int i = 0; i < len; ++i) {
vec[i] = REAL(array)[i];
}
XGDMatrixSetFloatInfo(R_ExternalPtrAddr(handle),
CHAR(asChar(field)),
BeginPtr(vec), len);
CheckErr(XGDMatrixSetFloatInfo(R_ExternalPtrAddr(handle),
CHAR(asChar(field)),
BeginPtr(vec), len));
}
_WrapperEnd();
}
SEXP XGDMatrixGetInfo_R(SEXP handle, SEXP field) {
_WrapperBegin();
bst_ulong olen;
const float *res = XGDMatrixGetFloatInfo(R_ExternalPtrAddr(handle),
CHAR(asChar(field)), &olen);
const float *res;
CheckErr(XGDMatrixGetFloatInfo(R_ExternalPtrAddr(handle),
CHAR(asChar(field)),
&olen,
&res));
_WrapperEnd();
SEXP ret = PROTECT(allocVector(REALSXP, olen));
for (size_t i = 0; i < olen; ++i) {
@@ -184,13 +198,14 @@ extern "C" {
return ret;
}
SEXP XGDMatrixNumRow_R(SEXP handle) {
bst_ulong nrow = XGDMatrixNumRow(R_ExternalPtrAddr(handle));
bst_ulong nrow;
CheckErr(XGDMatrixNumRow(R_ExternalPtrAddr(handle), &nrow));
return ScalarInteger(static_cast<int>(nrow));
}
// functions related to booster
void _BoosterFinalizer(SEXP ext) {
if (R_ExternalPtrAddr(ext) == NULL) return;
XGBoosterFree(R_ExternalPtrAddr(ext));
CheckErr(XGBoosterFree(R_ExternalPtrAddr(ext)));
R_ClearExternalPtr(ext);
}
SEXP XGBoosterCreate_R(SEXP dmats) {
@@ -200,7 +215,8 @@ extern "C" {
for (int i = 0; i < len; ++i) {
dvec.push_back(R_ExternalPtrAddr(VECTOR_ELT(dmats, i)));
}
void *handle = XGBoosterCreate(BeginPtr(dvec), dvec.size());
BoosterHandle handle;
CheckErr(XGBoosterCreate(BeginPtr(dvec), dvec.size(), &handle));
_WrapperEnd();
SEXP ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue));
R_RegisterCFinalizerEx(ret, _BoosterFinalizer, TRUE);
@@ -209,16 +225,16 @@ extern "C" {
}
void XGBoosterSetParam_R(SEXP handle, SEXP name, SEXP val) {
_WrapperBegin();
XGBoosterSetParam(R_ExternalPtrAddr(handle),
CHAR(asChar(name)),
CHAR(asChar(val)));
CheckErr(XGBoosterSetParam(R_ExternalPtrAddr(handle),
CHAR(asChar(name)),
CHAR(asChar(val))));
_WrapperEnd();
}
void XGBoosterUpdateOneIter_R(SEXP handle, SEXP iter, SEXP dtrain) {
_WrapperBegin();
XGBoosterUpdateOneIter(R_ExternalPtrAddr(handle),
asInteger(iter),
R_ExternalPtrAddr(dtrain));
CheckErr(XGBoosterUpdateOneIter(R_ExternalPtrAddr(handle),
asInteger(iter),
R_ExternalPtrAddr(dtrain)));
_WrapperEnd();
}
void XGBoosterBoostOneIter_R(SEXP handle, SEXP dtrain, SEXP grad, SEXP hess) {
@@ -231,9 +247,10 @@ extern "C" {
tgrad[j] = REAL(grad)[j];
thess[j] = REAL(hess)[j];
}
XGBoosterBoostOneIter(R_ExternalPtrAddr(handle),
R_ExternalPtrAddr(dtrain),
BeginPtr(tgrad), BeginPtr(thess), len);
CheckErr(XGBoosterBoostOneIter(R_ExternalPtrAddr(handle),
R_ExternalPtrAddr(dtrain),
BeginPtr(tgrad), BeginPtr(thess),
len));
_WrapperEnd();
}
SEXP XGBoosterEvalOneIter_R(SEXP handle, SEXP iter, SEXP dmats, SEXP evnames) {
@@ -250,21 +267,24 @@ extern "C" {
for (int i = 0; i < len; ++i) {
vec_sptr.push_back(vec_names[i].c_str());
}
const char *ret =
XGBoosterEvalOneIter(R_ExternalPtrAddr(handle),
asInteger(iter),
BeginPtr(vec_dmats), BeginPtr(vec_sptr), len);
const char *ret;
CheckErr(XGBoosterEvalOneIter(R_ExternalPtrAddr(handle),
asInteger(iter),
BeginPtr(vec_dmats),
BeginPtr(vec_sptr),
len, &ret));
_WrapperEnd();
return mkString(ret);
}
SEXP XGBoosterPredict_R(SEXP handle, SEXP dmat, SEXP option_mask, SEXP ntree_limit) {
_WrapperBegin();
bst_ulong olen;
const float *res = XGBoosterPredict(R_ExternalPtrAddr(handle),
R_ExternalPtrAddr(dmat),
asInteger(option_mask),
asInteger(ntree_limit),
&olen);
const float *res;
CheckErr(XGBoosterPredict(R_ExternalPtrAddr(handle),
R_ExternalPtrAddr(dmat),
asInteger(option_mask),
asInteger(ntree_limit),
&olen, &res));
_WrapperEnd();
SEXP ret = PROTECT(allocVector(REALSXP, olen));
for (size_t i = 0; i < olen; ++i) {
@@ -275,12 +295,12 @@ extern "C" {
}
void XGBoosterLoadModel_R(SEXP handle, SEXP fname) {
_WrapperBegin();
XGBoosterLoadModel(R_ExternalPtrAddr(handle), CHAR(asChar(fname)));
CheckErr(XGBoosterLoadModel(R_ExternalPtrAddr(handle), CHAR(asChar(fname))));
_WrapperEnd();
}
void XGBoosterSaveModel_R(SEXP handle, SEXP fname) {
_WrapperBegin();
XGBoosterSaveModel(R_ExternalPtrAddr(handle), CHAR(asChar(fname)));
CheckErr(XGBoosterSaveModel(R_ExternalPtrAddr(handle), CHAR(asChar(fname))));
_WrapperEnd();
}
void XGBoosterLoadModelFromRaw_R(SEXP handle, SEXP raw) {
@@ -293,7 +313,8 @@ extern "C" {
SEXP XGBoosterModelToRaw_R(SEXP handle) {
bst_ulong olen;
_WrapperBegin();
const char *raw = XGBoosterGetModelRaw(R_ExternalPtrAddr(handle), &olen);
const char *raw;
CheckErr(XGBoosterGetModelRaw(R_ExternalPtrAddr(handle), &olen, &raw));
_WrapperEnd();
SEXP ret = PROTECT(allocVector(RAWSXP, olen));
if (olen != 0) {
@@ -305,11 +326,11 @@ extern "C" {
SEXP XGBoosterDumpModel_R(SEXP handle, SEXP fmap, SEXP with_stats) {
_WrapperBegin();
bst_ulong olen;
const char **res =
XGBoosterDumpModel(R_ExternalPtrAddr(handle),
CHAR(asChar(fmap)),
asInteger(with_stats),
&olen);
const char **res;
CheckErr(XGBoosterDumpModel(R_ExternalPtrAddr(handle),
CHAR(asChar(fmap)),
asInteger(with_stats),
&olen, &res));
_WrapperEnd();
SEXP out = PROTECT(allocVector(STRSXP, olen));
for (size_t i = 0; i < olen; ++i) {