add beginPtr, to make vector address taking safe

This commit is contained in:
tqchen 2014-09-02 11:01:38 -07:00
parent 70219ee1ae
commit 27cabd131e
2 changed files with 20 additions and 14 deletions

View File

@ -10,6 +10,7 @@
#include "src/utils/matrix_csr.h" #include "src/utils/matrix_csr.h"
using namespace std; using namespace std;
using namespace xgboost; using namespace xgboost;
using namespace xgboost::utils;
extern "C" { extern "C" {
void XGBoostAssert_R(int exp, const char *fmt, ...); void XGBoostAssert_R(int exp, const char *fmt, ...);
@ -80,7 +81,7 @@ extern "C" {
data[i * ncol +j] = din[i + nrow * j]; data[i * ncol +j] = din[i + nrow * j];
} }
} }
void *handle = XGDMatrixCreateFromMat(&data[0], nrow, ncol, asReal(missing)); void *handle = XGDMatrixCreateFromMat(BeginPtr(data), nrow, ncol, asReal(missing));
SEXP ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue)); SEXP ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue));
R_RegisterCFinalizerEx(ret, _DMatrixFinalizer, TRUE); R_RegisterCFinalizerEx(ret, _DMatrixFinalizer, TRUE);
UNPROTECT(1); UNPROTECT(1);
@ -120,7 +121,8 @@ extern "C" {
col_index[i] = csr_data[i].first; col_index[i] = csr_data[i].first;
row_data[i] = csr_data[i].second; row_data[i] = csr_data[i].second;
} }
void *handle = XGDMatrixCreateFromCSR(&row_ptr[0], &col_index[0], &row_data[0], row_ptr.size(), ndata ); void *handle = XGDMatrixCreateFromCSR(BeginPtr(row_ptr), BeginPtr(col_index),
BeginPtr(row_data), row_ptr.size(), ndata );
SEXP ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue)); SEXP ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue));
R_RegisterCFinalizerEx(ret, _DMatrixFinalizer, TRUE); R_RegisterCFinalizerEx(ret, _DMatrixFinalizer, TRUE);
UNPROTECT(1); UNPROTECT(1);
@ -134,7 +136,7 @@ extern "C" {
for (int i = 0; i < len; ++i) { for (int i = 0; i < len; ++i) {
idxvec[i] = INTEGER(idxset)[i] - 1; idxvec[i] = INTEGER(idxset)[i] - 1;
} }
void *res = XGDMatrixSliceDMatrix(R_ExternalPtrAddr(handle), &idxvec[0], len); void *res = XGDMatrixSliceDMatrix(R_ExternalPtrAddr(handle), BeginPtr(idxvec), len);
SEXP ret = PROTECT(R_MakeExternalPtr(res, R_NilValue, R_NilValue)); SEXP ret = PROTECT(R_MakeExternalPtr(res, R_NilValue, R_NilValue));
R_RegisterCFinalizerEx(ret, _DMatrixFinalizer, TRUE); R_RegisterCFinalizerEx(ret, _DMatrixFinalizer, TRUE);
UNPROTECT(1); UNPROTECT(1);
@ -157,7 +159,7 @@ extern "C" {
for (int i = 0; i < len; ++i) { for (int i = 0; i < len; ++i) {
vec[i] = static_cast<unsigned>(INTEGER(array)[i]); vec[i] = static_cast<unsigned>(INTEGER(array)[i]);
} }
XGDMatrixSetGroup(R_ExternalPtrAddr(handle), &vec[0], len); XGDMatrixSetGroup(R_ExternalPtrAddr(handle), BeginPtr(vec), len);
_WrapperEnd(); _WrapperEnd();
return; return;
} }
@ -169,7 +171,7 @@ extern "C" {
} }
XGDMatrixSetFloatInfo(R_ExternalPtrAddr(handle), XGDMatrixSetFloatInfo(R_ExternalPtrAddr(handle),
CHAR(asChar(field)), CHAR(asChar(field)),
&vec[0], len); BeginPtr(vec), len);
} }
_WrapperEnd(); _WrapperEnd();
} }
@ -199,12 +201,7 @@ extern "C" {
for (int i = 0; i < len; ++i){ for (int i = 0; i < len; ++i){
dvec.push_back(R_ExternalPtrAddr(VECTOR_ELT(dmats, i))); dvec.push_back(R_ExternalPtrAddr(VECTOR_ELT(dmats, i)));
} }
void *handle; void *handle = XGBoosterCreate(BeginPtr(dvec), dvec.size());
if (dvec.size() == 0) {
handle = XGBoosterCreate(NULL, 0);
} else {
handle = XGBoosterCreate(&dvec[0], dvec.size());
}
SEXP ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue)); SEXP ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue));
R_RegisterCFinalizerEx(ret, _BoosterFinalizer, TRUE); R_RegisterCFinalizerEx(ret, _BoosterFinalizer, TRUE);
UNPROTECT(1); UNPROTECT(1);
@ -237,7 +234,7 @@ extern "C" {
} }
XGBoosterBoostOneIter(R_ExternalPtrAddr(handle), XGBoosterBoostOneIter(R_ExternalPtrAddr(handle),
R_ExternalPtrAddr(dtrain), R_ExternalPtrAddr(dtrain),
&tgrad[0], &thess[0], len); BeginPtr(tgrad), BeginPtr(thess), len);
_WrapperEnd(); _WrapperEnd();
} }
SEXP XGBoosterEvalOneIter_R(SEXP handle, SEXP iter, SEXP dmats, SEXP evnames) { SEXP XGBoosterEvalOneIter_R(SEXP handle, SEXP iter, SEXP dmats, SEXP evnames) {
@ -256,7 +253,7 @@ extern "C" {
} }
return mkString(XGBoosterEvalOneIter(R_ExternalPtrAddr(handle), return mkString(XGBoosterEvalOneIter(R_ExternalPtrAddr(handle),
asInteger(iter), asInteger(iter),
&vec_dmats[0], &vec_sptr[0], len)); BeginPtr(vec_dmats), BeginPtr(vec_sptr), len));
_WrapperEnd(); _WrapperEnd();
} }
SEXP XGBoosterPredict_R(SEXP handle, SEXP dmat, SEXP output_margin, SEXP ntree_limit) { SEXP XGBoosterPredict_R(SEXP handle, SEXP dmat, SEXP output_margin, SEXP ntree_limit) {

View File

@ -9,6 +9,7 @@
#include <cstdio> #include <cstdio>
#include <string> #include <string>
#include <cstdlib> #include <cstdlib>
#include <vector>
#ifndef XGBOOST_STRICT_CXX98_ #ifndef XGBOOST_STRICT_CXX98_
#include <cstdarg> #include <cstdarg>
@ -153,7 +154,15 @@ inline FILE *FopenCheck(const char *fname, const char *flag) {
Check(fp != NULL, "can not open file \"%s\"\n", fname); Check(fp != NULL, "can not open file \"%s\"\n", fname);
return fp; return fp;
} }
/*! \brief get the beginning address of a vector */
template<typename T>
inline T *BeginPtr(std::vector<T> &vec) {
if (vec.size() == 0) {
return NULL;
} else {
return &vec[0];
}
}
} // namespace utils } // namespace utils
} // namespace xgboost } // namespace xgboost
#endif // XGBOOST_UTILS_UTILS_H_ #endif // XGBOOST_UTILS_UTILS_H_