allow standalone random

This commit is contained in:
tqchen 2014-08-31 14:07:44 -07:00
parent ba4f00d55d
commit 168f78623f
5 changed files with 77 additions and 31 deletions

View File

@ -49,7 +49,6 @@ xgb.Booster <- function(params = list(), cachelist = list(), modelfile = NULL) {
} }
} }
handle <- .Call("XGBoosterCreate_R", cachelist, PACKAGE = "xgboost") handle <- .Call("XGBoosterCreate_R", cachelist, PACKAGE = "xgboost")
.Call("XGBoosterSetParam_R", handle, "seed", "0", PACKAGE = "xgboost")
if (length(params) != 0) { if (length(params) != 0) {
for (i in 1:length(params)) { for (i in 1:length(params)) {
p <- params[i] p <- params[i]

View File

@ -1,7 +1,7 @@
# package root # package root
PKGROOT=../../ PKGROOT=../../
# _*_ mode: Makefile; _*_ # _*_ mode: Makefile; _*_
PKG_CPPFLAGS= -DXGBOOST_CUSTOMIZE_MSG_ -I$(PKGROOT) PKG_CPPFLAGS= -DXGBOOST_CUSTOMIZE_MSG_ -DXGBOOST_CUSTOMIZE_PRNG_ -I$(PKGROOT)
PKG_CXXFLAGS= $(SHLIB_OPENMP_CFLAGS) PKG_CXXFLAGS= $(SHLIB_OPENMP_CFLAGS)
PKG_LIBS = $(SHLIB_OPENMP_CFLAGS) PKG_LIBS = $(SHLIB_OPENMP_CFLAGS)

View File

@ -1,8 +1,8 @@
#include "xgboost_R.h"
#include <vector> #include <vector>
#include <string> #include <string>
#include <utility> #include <utility>
#include <cstring> #include <cstring>
#include "xgboost_R.h"
#include "wrapper/xgboost_wrapper.h" #include "wrapper/xgboost_wrapper.h"
#include "src/utils/utils.h" #include "src/utils/utils.h"
#include "src/utils/omp.h" #include "src/utils/omp.h"
@ -22,8 +22,28 @@ void HandlePrint(const char *msg) {
Rprintf("%s", msg); Rprintf("%s", msg);
} }
} // namespace utils } // namespace utils
namespace random {
void Seed(unsigned seed) {
warning("parameter seed is ignored, please set random seed using set.seed");
}
double Uniform(void) {
return unif_rand();
}
double Normal(void) {
return norm_rand();
}
} // namespace random
} // namespace xgboost } // namespace xgboost
// call before wrapper starts
inline void _WrapperBegin(void) {
GetRNGstate();
}
// call after wrapper starts
inline void _WrapperEnd(void) {
PutRNGstate();
}
extern "C" { extern "C" {
void _DMatrixFinalizer(SEXP ext) { void _DMatrixFinalizer(SEXP ext) {
if (R_ExternalPtrAddr(ext) == NULL) return; if (R_ExternalPtrAddr(ext) == NULL) return;
@ -31,14 +51,17 @@ extern "C" {
R_ClearExternalPtr(ext); R_ClearExternalPtr(ext);
} }
SEXP XGDMatrixCreateFromFile_R(SEXP fname, SEXP silent) { SEXP XGDMatrixCreateFromFile_R(SEXP fname, SEXP silent) {
_WrapperBegin();
void *handle = XGDMatrixCreateFromFile(CHAR(asChar(fname)), asInteger(silent)); void *handle = XGDMatrixCreateFromFile(CHAR(asChar(fname)), asInteger(silent));
SEXP ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue)); SEXP ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue));
R_RegisterCFinalizerEx(ret, _DMatrixFinalizer, TRUE); R_RegisterCFinalizerEx(ret, _DMatrixFinalizer, TRUE);
UNPROTECT(1); UNPROTECT(1);
_WrapperEnd();
return ret; return ret;
} }
SEXP XGDMatrixCreateFromMat_R(SEXP mat, SEXP XGDMatrixCreateFromMat_R(SEXP mat,
SEXP missing) { SEXP missing) {
_WrapperBegin();
SEXP dim = getAttrib(mat, R_DimSymbol); SEXP dim = getAttrib(mat, R_DimSymbol);
int nrow = INTEGER(dim)[0]; int nrow = INTEGER(dim)[0];
int ncol = INTEGER(dim)[1]; int ncol = INTEGER(dim)[1];
@ -54,11 +77,13 @@ extern "C" {
SEXP ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue)); SEXP ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue));
R_RegisterCFinalizerEx(ret, _DMatrixFinalizer, TRUE); R_RegisterCFinalizerEx(ret, _DMatrixFinalizer, TRUE);
UNPROTECT(1); UNPROTECT(1);
_WrapperEnd();
return ret; return ret;
} }
SEXP XGDMatrixCreateFromCSC_R(SEXP indptr, SEXP XGDMatrixCreateFromCSC_R(SEXP indptr,
SEXP indices, SEXP indices,
SEXP data) { SEXP data) {
_WrapperBegin();
const int *col_ptr = INTEGER(indptr); const int *col_ptr = INTEGER(indptr);
const int *row_index = INTEGER(indices); const int *row_index = INTEGER(indices);
const double *col_data = REAL(data); const double *col_data = REAL(data);
@ -92,9 +117,11 @@ extern "C" {
SEXP ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue)); SEXP ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue));
R_RegisterCFinalizerEx(ret, _DMatrixFinalizer, TRUE); R_RegisterCFinalizerEx(ret, _DMatrixFinalizer, TRUE);
UNPROTECT(1); UNPROTECT(1);
_WrapperEnd();
return ret; return ret;
} }
SEXP XGDMatrixSliceDMatrix_R(SEXP handle, SEXP idxset) { SEXP XGDMatrixSliceDMatrix_R(SEXP handle, SEXP idxset) {
_WrapperBegin();
int len = length(idxset); int len = length(idxset);
std::vector<int> idxvec(len); std::vector<int> idxvec(len);
for (int i = 0; i < len; ++i) { for (int i = 0; i < len; ++i) {
@ -104,13 +131,17 @@ extern "C" {
SEXP ret = PROTECT(R_MakeExternalPtr(res, R_NilValue, R_NilValue)); SEXP ret = PROTECT(R_MakeExternalPtr(res, R_NilValue, R_NilValue));
R_RegisterCFinalizerEx(ret, _DMatrixFinalizer, TRUE); R_RegisterCFinalizerEx(ret, _DMatrixFinalizer, TRUE);
UNPROTECT(1); UNPROTECT(1);
_WrapperEnd();
return ret; return ret;
} }
void XGDMatrixSaveBinary_R(SEXP handle, SEXP fname, SEXP silent) { void XGDMatrixSaveBinary_R(SEXP handle, SEXP fname, SEXP silent) {
_WrapperBegin();
XGDMatrixSaveBinary(R_ExternalPtrAddr(handle), XGDMatrixSaveBinary(R_ExternalPtrAddr(handle),
CHAR(asChar(fname)), asInteger(silent)); CHAR(asChar(fname)), asInteger(silent));
_WrapperEnd();
} }
void XGDMatrixSetInfo_R(SEXP handle, SEXP field, SEXP array) { void XGDMatrixSetInfo_R(SEXP handle, SEXP field, SEXP array) {
_WrapperBegin();
int len = length(array); int len = length(array);
const char *name = CHAR(asChar(field)); const char *name = CHAR(asChar(field));
if (!strcmp("group", name)) { if (!strcmp("group", name)) {
@ -120,6 +151,7 @@ extern "C" {
vec[i] = static_cast<unsigned>(INTEGER(array)[i]); vec[i] = static_cast<unsigned>(INTEGER(array)[i]);
} }
XGDMatrixSetGroup(R_ExternalPtrAddr(handle), &vec[0], len); XGDMatrixSetGroup(R_ExternalPtrAddr(handle), &vec[0], len);
_WrapperEnd();
return; return;
} }
{ {
@ -132,8 +164,10 @@ extern "C" {
CHAR(asChar(field)), CHAR(asChar(field)),
&vec[0], len); &vec[0], len);
} }
_WrapperEnd();
} }
SEXP XGDMatrixGetInfo_R(SEXP handle, SEXP field) { SEXP XGDMatrixGetInfo_R(SEXP handle, SEXP field) {
_WrapperBegin();
bst_ulong olen; bst_ulong olen;
const float *res = XGDMatrixGetFloatInfo(R_ExternalPtrAddr(handle), const float *res = XGDMatrixGetFloatInfo(R_ExternalPtrAddr(handle),
CHAR(asChar(field)), &olen); CHAR(asChar(field)), &olen);
@ -142,6 +176,7 @@ extern "C" {
REAL(ret)[i] = res[i]; REAL(ret)[i] = res[i];
} }
UNPROTECT(1); UNPROTECT(1);
_WrapperEnd();
return ret; return ret;
} }
// functions related to booster // functions related to booster
@ -151,6 +186,7 @@ extern "C" {
R_ClearExternalPtr(ext); R_ClearExternalPtr(ext);
} }
SEXP XGBoosterCreate_R(SEXP dmats) { SEXP XGBoosterCreate_R(SEXP dmats) {
_WrapperBegin();
int len = length(dmats); int len = length(dmats);
std::vector<void*> dvec; std::vector<void*> dvec;
for (int i = 0; i < len; ++i){ for (int i = 0; i < len; ++i){
@ -160,19 +196,25 @@ extern "C" {
SEXP ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue)); SEXP ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue));
R_RegisterCFinalizerEx(ret, _BoosterFinalizer, TRUE); R_RegisterCFinalizerEx(ret, _BoosterFinalizer, TRUE);
UNPROTECT(1); UNPROTECT(1);
_WrapperEnd();
return ret; return ret;
} }
void XGBoosterSetParam_R(SEXP handle, SEXP name, SEXP val) { void XGBoosterSetParam_R(SEXP handle, SEXP name, SEXP val) {
_WrapperBegin();
XGBoosterSetParam(R_ExternalPtrAddr(handle), XGBoosterSetParam(R_ExternalPtrAddr(handle),
CHAR(asChar(name)), CHAR(asChar(name)),
CHAR(asChar(val))); CHAR(asChar(val)));
_WrapperEnd();
} }
void XGBoosterUpdateOneIter_R(SEXP handle, SEXP iter, SEXP dtrain) { void XGBoosterUpdateOneIter_R(SEXP handle, SEXP iter, SEXP dtrain) {
_WrapperBegin();
XGBoosterUpdateOneIter(R_ExternalPtrAddr(handle), XGBoosterUpdateOneIter(R_ExternalPtrAddr(handle),
asInteger(iter), asInteger(iter),
R_ExternalPtrAddr(dtrain)); R_ExternalPtrAddr(dtrain));
_WrapperEnd();
} }
void XGBoosterBoostOneIter_R(SEXP handle, SEXP dtrain, SEXP grad, SEXP hess) { void XGBoosterBoostOneIter_R(SEXP handle, SEXP dtrain, SEXP grad, SEXP hess) {
_WrapperBegin();
utils::Check(length(grad) == length(hess), "gradient and hess must have same length"); utils::Check(length(grad) == length(hess), "gradient and hess must have same length");
int len = length(grad); int len = length(grad);
std::vector<float> tgrad(len), thess(len); std::vector<float> tgrad(len), thess(len);
@ -184,8 +226,10 @@ extern "C" {
XGBoosterBoostOneIter(R_ExternalPtrAddr(handle), XGBoosterBoostOneIter(R_ExternalPtrAddr(handle),
R_ExternalPtrAddr(dtrain), R_ExternalPtrAddr(dtrain),
&tgrad[0], &thess[0], len); &tgrad[0], &thess[0], len);
_WrapperEnd();
} }
SEXP XGBoosterEvalOneIter_R(SEXP handle, SEXP iter, SEXP dmats, SEXP evnames) { SEXP XGBoosterEvalOneIter_R(SEXP handle, SEXP iter, SEXP dmats, SEXP evnames) {
_WrapperBegin();
utils::Check(length(dmats) == length(evnames), "dmats and evnams must have same length"); utils::Check(length(dmats) == length(evnames), "dmats and evnams must have same length");
int len = length(dmats); int len = length(dmats);
std::vector<void*> vec_dmats; std::vector<void*> vec_dmats;
@ -201,8 +245,10 @@ extern "C" {
return mkString(XGBoosterEvalOneIter(R_ExternalPtrAddr(handle), return mkString(XGBoosterEvalOneIter(R_ExternalPtrAddr(handle),
asInteger(iter), asInteger(iter),
&vec_dmats[0], &vec_sptr[0], len)); &vec_dmats[0], &vec_sptr[0], len));
_WrapperEnd();
} }
SEXP XGBoosterPredict_R(SEXP handle, SEXP dmat, SEXP output_margin) { SEXP XGBoosterPredict_R(SEXP handle, SEXP dmat, SEXP output_margin) {
_WrapperBegin();
bst_ulong olen; bst_ulong olen;
const float *res = XGBoosterPredict(R_ExternalPtrAddr(handle), const float *res = XGBoosterPredict(R_ExternalPtrAddr(handle),
R_ExternalPtrAddr(dmat), R_ExternalPtrAddr(dmat),
@ -213,15 +259,21 @@ extern "C" {
REAL(ret)[i] = res[i]; REAL(ret)[i] = res[i];
} }
UNPROTECT(1); UNPROTECT(1);
_WrapperEnd();
return ret; return ret;
} }
void XGBoosterLoadModel_R(SEXP handle, SEXP fname) { void XGBoosterLoadModel_R(SEXP handle, SEXP fname) {
_WrapperBegin();
XGBoosterLoadModel(R_ExternalPtrAddr(handle), CHAR(asChar(fname))); XGBoosterLoadModel(R_ExternalPtrAddr(handle), CHAR(asChar(fname)));
_WrapperEnd();
} }
void XGBoosterSaveModel_R(SEXP handle, SEXP fname) { void XGBoosterSaveModel_R(SEXP handle, SEXP fname) {
_WrapperBegin();
XGBoosterSaveModel(R_ExternalPtrAddr(handle), CHAR(asChar(fname))); XGBoosterSaveModel(R_ExternalPtrAddr(handle), CHAR(asChar(fname)));
_WrapperEnd();
} }
void XGBoosterDumpModel_R(SEXP handle, SEXP fname, SEXP fmap) { void XGBoosterDumpModel_R(SEXP handle, SEXP fname, SEXP fmap) {
_WrapperBegin();
bst_ulong olen; bst_ulong olen;
const char **res = XGBoosterDumpModel(R_ExternalPtrAddr(handle), const char **res = XGBoosterDumpModel(R_ExternalPtrAddr(handle),
CHAR(asChar(fmap)), CHAR(asChar(fmap)),
@ -232,5 +284,6 @@ extern "C" {
fprintf(fo, "%s", res[i]); fprintf(fo, "%s", res[i]);
} }
fclose(fo); fclose(fo);
_WrapperEnd();
} }
} }

View File

@ -7,6 +7,7 @@
*/ */
extern "C" { extern "C" {
#include <Rinternals.h> #include <Rinternals.h>
#include <R_ext/Random.h>
} }
extern "C" { extern "C" {

View File

@ -16,30 +16,21 @@
/*! namespace of PRNG */ /*! namespace of PRNG */
namespace xgboost { namespace xgboost {
namespace random { namespace random {
#ifndef XGBOOST_CUSTOMIZE_PRNG_
/*! \brief seed the PRNG */ /*! \brief seed the PRNG */
inline void Seed(uint32_t seed) { inline void Seed(unsigned seed) {
srand(seed); srand(seed);
} }
/*! \brief return a real number uniform in [0,1) */ /*! \brief basic function, uniform */
inline double NextDouble(void) { inline double Uniform(void) {
return static_cast<double>(rand()) / (static_cast<double>(RAND_MAX)+1.0); return static_cast<double>(rand()) / (static_cast<double>(RAND_MAX)+1.0);
} }
/*! \brief return a real numer uniform in (0,1) */ /*! \brief return a real numer uniform in (0,1) */
inline double NextDouble2(void) { inline double NextDouble2(void) {
return (static_cast<double>(rand()) + 1.0) / (static_cast<double>(RAND_MAX)+2.0); return (static_cast<double>(rand()) + 1.0) / (static_cast<double>(RAND_MAX)+2.0);
} }
/*! \brief return a random number */
inline uint32_t NextUInt32(void) {
return (uint32_t)rand();
}
/*! \brief return a random number in n */
inline uint32_t NextUInt32(uint32_t n) {
return (uint32_t)floor(NextDouble() * n);
}
/*! \brief return x~N(0,1) */ /*! \brief return x~N(0,1) */
inline double SampleNormal() { inline double Normal(void) {
double x, y, s; double x, y, s;
do { do {
x = 2 * NextDouble2() - 1.0; x = 2 * NextDouble2() - 1.0;
@ -49,22 +40,24 @@ inline double SampleNormal() {
return x * sqrt(-2.0 * log(s) / s); return x * sqrt(-2.0 * log(s) / s);
} }
#else
// include declarations, to be implemented
void Seed(unsigned seed);
double Uniform(void);
double Normal(void);
#endif
/*! \brief return iid x,y ~N(0,1) */ /*! \brief return a real number uniform in [0,1) */
inline void SampleNormal2D(double &xx, double &yy) { inline double NextDouble(void) {
double x, y, s; return Uniform();
do { }
x = 2 * NextDouble2() - 1.0; /*! \brief return a random number in n */
y = 2 * NextDouble2() - 1.0; inline uint32_t NextUInt32(uint32_t n) {
s = x*x + y*y; return (uint32_t)floor(NextDouble() * n);
} while (s >= 1.0 || s == 0.0);
double t = sqrt(-2.0 * log(s) / s);
xx = x * t;
yy = y * t;
} }
/*! \brief return x~N(mu,sigma^2) */ /*! \brief return x~N(mu,sigma^2) */
inline double SampleNormal(double mu, double sigma) { inline double SampleNormal(double mu, double sigma) {
return SampleNormal() * sigma + mu; return Normal() * sigma + mu;
} }
/*! \brief return 1 with probability p, coin flip */ /*! \brief return 1 with probability p, coin flip */
inline int SampleBinary(double p) { inline int SampleBinary(double p) {
@ -90,7 +83,7 @@ struct Random{
inline void Seed(unsigned sd) { inline void Seed(unsigned sd) {
this->rseed = sd; this->rseed = sd;
#if defined(_MSC_VER)||defined(_WIN32) #if defined(_MSC_VER)||defined(_WIN32)
srand(rseed); ::xgboost::utils::Seed(sd);
#endif #endif
} }
/*! \brief return a real number uniform in [0,1) */ /*! \brief return a real number uniform in [0,1) */
@ -99,7 +92,7 @@ struct Random{
// For cygwin and mingw, this can slows down parallelism, but rand_r is only used in objective-inl.hpp, won't affect speed in general // For cygwin and mingw, this can slows down parallelism, but rand_r is only used in objective-inl.hpp, won't affect speed in general
// todo, replace with another PRNG // todo, replace with another PRNG
#if defined(_MSC_VER)||defined(_WIN32) #if defined(_MSC_VER)||defined(_WIN32)
return static_cast<double>(rand()) / (static_cast<double>(RAND_MAX) + 1.0); return Uniform();
#else #else
return static_cast<double>(rand_r(&rseed)) / (static_cast<double>(RAND_MAX) + 1.0); return static_cast<double>(rand_r(&rseed)) / (static_cast<double>(RAND_MAX) + 1.0);
#endif #endif