workable R wrapper

This commit is contained in:
tqchen 2014-08-23 12:14:44 -07:00
parent 5e23f6577f
commit 40da2fa2c0
5 changed files with 164 additions and 9 deletions

View File

@ -11,14 +11,15 @@ endif
# specify tensor path # specify tensor path
BIN = xgboost BIN = xgboost
OBJ = OBJ =
SLIB = python/libxgboostwrapper.so SLIB = python/libxgboostR.so
.PHONY: clean all .PHONY: clean all
all: $(BIN) $(OBJ) $(SLIB) all: $(BIN) $(OBJ) $(SLIB)
xgboost: src/xgboost_main.cpp src/io/io.cpp src/data.h src/tree/*.h src/tree/*.hpp src/gbm/*.h src/gbm/*.hpp src/utils/*.h src/learner/*.h src/learner/*.hpp xgboost: src/xgboost_main.cpp src/io/io.cpp src/data.h src/tree/*.h src/tree/*.hpp src/gbm/*.h src/gbm/*.hpp src/utils/*.h src/learner/*.h src/learner/*.hpp
# now the wrapper takes in two files. io and wrapper part # now the wrapper takes in two files. io and wrapper part
python/libxgboostwrapper.so: python/xgboost_wrapper.cpp src/io/io.cpp src/*.h src/*/*.hpp src/*/*.h #python/libxgboostwrapper.so: python/xgboost_wrapper.cpp src/io/io.cpp src/*.h src/*/*.hpp src/*/*.h
python/libxgboostR.so: python/xgboost_wrapper_R.cpp python/xgboost_wrapper.cpp src/io/io.cpp src/*.h src/*/*.hpp src/*/*.h
$(BIN) : $(BIN) :
$(CXX) $(CFLAGS) $(LDFLAGS) -o $@ $(filter %.cpp %.o %.c, $^) $(CXX) $(CFLAGS) $(LDFLAGS) -o $@ $(filter %.cpp %.o %.c, $^)

View File

@ -1,2 +1,82 @@
#su # load in library
dyn.load() dyn.load("libxgboostR.so")
# Build an xgb.DMatrix from a data source.
#   data: character value naming the file to load; it is passed unchanged
#         to the native loader XGDMatrixCreateFromFile_R.
# Returns the native handle tagged with class "xgb.DMatrix".
# Any other input type is rejected with an error.
xgb.DMatrix <- function(data) {
  # guard clause: only file names are understood for now
  if (typeof(data) != "character") {
    stop("xgb.DMatrix cannot recognize data type")
  }
  ptr <- .Call("XGDMatrixCreateFromFile_R", data)
  structure(ptr, class = "xgb.DMatrix")
}
# Construct a Booster from a list of cached DMatrix objects.
#   cachelist: list whose elements all inherit from "xgb.DMatrix"; it is
#              handed to the native XGBoosterCreate_R.
#   params:    named list of parameters; each value is converted with
#              as.character() and set on the booster under its name.
#              Defaults to an empty list (no extra parameters).
# Returns the native handle tagged with class "xgb.Booster".
# The booster is always created with "silent" = "1" before params are applied.
xgb.Booster <- function(cachelist, params = list()) {
  if (typeof(cachelist) != "list") {
    stop("xgb.Booster: only accepts list of DMatrix as cachelist")
  }
  for (dm in cachelist) {
    # inherits() stays a single TRUE/FALSE even when an object carries
    # several classes, unlike comparing against class(dm)
    if (!inherits(dm, "xgb.DMatrix")) {
      stop("xgb.Booster: only accepts list of DMatrix as cachelist")
    }
  }
  handle <- .Call("XGBoosterCreate_R", cachelist)
  .Call("XGBoosterSetParam_R", handle, "silent", "1")
  # seq_along() yields an empty sequence for empty params, whereas
  # 1:length(params) would iterate over c(1, 0)
  for (i in seq_along(params)) {
    p <- params[i]
    .Call("XGBoosterSetParam_R", handle, names(p), as.character(p))
  }
  return(structure(handle, class = "xgb.Booster"))
}
# Update the booster for one iteration using dtrain.
#   booster: object inheriting from "xgb.Booster"
#   dtrain:  object inheriting from "xgb.DMatrix" used as training data
#   iter:    iteration number; converted with as.integer() before the call
# Returns TRUE after the native XGBoosterUpdateOneIter_R call completes.
xgb.update <- function(booster, dtrain, iter) {
  # inherits() gives a single logical even for objects with several classes
  if (!inherits(booster, "xgb.Booster")) {
    stop("xgb.update: first argument must be type xgb.Booster")
  }
  if (!inherits(dtrain, "xgb.DMatrix")) {
    stop("xgb.update: second argument must be type xgb.DMatrix")
  }
  .Call("XGBoosterUpdateOneIter_R", booster, as.integer(iter), dtrain)
  return(TRUE)
}
# Evaluate the booster on every matrix in watchlist for one iteration.
#   booster:   object inheriting from "xgb.Booster"
#   watchlist: non-empty named list of xgb.DMatrix objects; each name is
#              passed to the native call as the tag of the matching matrix
#   iter:      iteration number; converted with as.integer() before the call
# Returns the value produced by the native XGBoosterEvalOneIter_R call.
xgb.eval <- function(booster, watchlist, iter) {
  if (!inherits(booster, "xgb.Booster")) {
    stop("xgb.eval: first argument must be type xgb.Booster")
  }
  if (typeof(watchlist) != "list") {
    stop("xgb.eval: only accepts list of DMatrix as watchlist")
  }
  # an empty watchlist has nothing to evaluate; fail with an accurate message
  # instead of letting a 1:0 loop report a missing name tag
  if (length(watchlist) == 0) {
    stop("xgb.eval: watchlist must contain at least one xgb.DMatrix")
  }
  for (w in watchlist) {
    if (!inherits(w, "xgb.DMatrix")) {
      stop("xgb.eval: watch list can only contain xgb.DMatrix")
    }
  }
  # names() is NULL for a fully unnamed list and holds "" for the unnamed
  # entries of a partially named one; both cases must be rejected
  tags <- names(watchlist)
  if (is.null(tags) || any(is.na(tags) | tags == "")) {
    stop("xgb.eval: name tag must be presented for every elements in watchlist")
  }
  # the native side reads the names element by element, so pass them as a list
  evnames <- as.list(tags)
  msg <- .Call("XGBoosterEvalOneIter_R", booster, as.integer(iter), watchlist, evnames)
  return(msg)
}
# Smoke test: load the agaricus example data, run a single boosting round
# and print the evaluation message for the train and test sets.
dtrain <- xgb.DMatrix("example/agaricus.txt.train")
dtest <- xgb.DMatrix("example/agaricus.txt.test")
param <- list("bst:min_child_weight" = 10, "objective" = "binary:logistic")
bst <- xgb.Booster(list(dtrain, dtest), param)
success <- xgb.update(bst, dtrain, 0)
watchlist <- list("train" = dtrain, "test" = dtest)
# emit the message first, then the line break, as two separate writes
cat(xgb.eval(bst, watchlist, 0))
cat("\n")

View File

@ -114,7 +114,7 @@ extern "C" {
* \param handle handle * \param handle handle
* \param iter current iteration rounds * \param iter current iteration rounds
* \param dtrain training data * \param dtrain training data
*/ */
void XGBoosterUpdateOneIter(void *handle, int iter, void *dtrain); void XGBoosterUpdateOneIter(void *handle, int iter, void *dtrain);
/*! /*!
* \brief update the model, by directly specify gradient and second order gradient, * \brief update the model, by directly specify gradient and second order gradient,
@ -127,7 +127,7 @@ extern "C" {
*/ */
void XGBoosterBoostOneIter(void *handle, void *dtrain, void XGBoosterBoostOneIter(void *handle, void *dtrain,
float *grad, float *hess, size_t len); float *grad, float *hess, size_t len);
/*! /*!
* \brief get evaluation statistics for xgboost * \brief get evaluation statistics for xgboost
* \param handle handle * \param handle handle
* \param iter current iteration rounds * \param iter current iteration rounds
@ -135,7 +135,7 @@ extern "C" {
* \param evnames pointers to names of each data * \param evnames pointers to names of each data
* \param len length of dmats * \param len length of dmats
* \return the string containing evaluation stati * \return the string containing evaluation stati
*/ */
const char *XGBoosterEvalOneIter(void *handle, int iter, void *dmats[], const char *XGBoosterEvalOneIter(void *handle, int iter, void *dmats[],
const char *evnames[], size_t len); const char *evnames[], size_t len);
/*! /*!

View File

@ -1,5 +1,7 @@
#include "xgboost_wrapper_R.h" #include <vector>
#include <string>
#include "xgboost_wrapper.h" #include "xgboost_wrapper.h"
#include "xgboost_wrapper_R.h"
#include "../src/utils/utils.h" #include "../src/utils/utils.h"
using namespace xgboost; using namespace xgboost;
@ -16,4 +18,48 @@ extern "C" {
UNPROTECT(1); UNPROTECT(1);
return ret; return ret;
} }
// functions related to booster
// Finalizer for booster external pointers: frees the booster behind ext
// and clears the pointer; does nothing when the address is already NULL.
void _BoosterFinalizer(SEXP ext) {
  void *booster = R_ExternalPtrAddr(ext);
  if (booster != NULL) {
    XGBoosterFree(booster);
    R_ClearExternalPtr(ext);
  }
}
// Create a booster from an R list of DMatrix external pointers.
// dmats: R list; the address stored in every element is collected and
//        forwarded to XGBoosterCreate.
// Returns an external pointer wrapping the new booster handle, with
// _BoosterFinalizer registered to release it.
SEXP XGBoosterCreate_R(SEXP dmats) {
  const int len = length(dmats);
  std::vector<void*> dvec;
  dvec.reserve(len);
  for (int i = 0; i < len; ++i) {
    dvec.push_back(R_ExternalPtrAddr(VECTOR_ELT(dmats, i)));
  }
  // &dvec[0] is undefined on an empty vector; pass NULL with size 0 instead
  void *handle = XGBoosterCreate(dvec.empty() ? NULL : &dvec[0], dvec.size());
  SEXP ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue));
  R_RegisterCFinalizerEx(ret, _BoosterFinalizer, TRUE);
  UNPROTECT(1);
  return ret;
}
// Set one parameter on the booster stored in handle.
// name and val are converted with asChar before being forwarded to
// XGBoosterSetParam as C strings.
// NOTE(review): this is invoked through .Call from R, which expects the
// native routine to return a SEXP; the void return type is declared in the
// header as well, so confirm and change both together if needed.
void XGBoosterSetParam_R(SEXP handle, SEXP name, SEXP val) {
  XGBoosterSetParam(
      R_ExternalPtrAddr(handle),
      CHAR(asChar(name)),
      CHAR(asChar(val)));
}
// Run one update round of the booster in handle on the matrix in dtrain.
// iter is read with asInteger; both handles are unwrapped from their
// external pointers before calling XGBoosterUpdateOneIter.
// NOTE(review): called through .Call from R, which expects a SEXP return
// value; the void return type matches the header declaration, so confirm
// and change both together if needed.
void XGBoosterUpdateOneIter_R(SEXP handle, SEXP iter, SEXP dtrain) {
  XGBoosterUpdateOneIter(
      R_ExternalPtrAddr(handle),
      asInteger(iter),
      R_ExternalPtrAddr(dtrain));
}
// Evaluate the booster in handle on every matrix of dmats for one iteration.
// handle:  external pointer to the booster
// iter:    iteration number, read with asInteger
// dmats:   R list of DMatrix external pointers
// evnames: R list of names, one per entry of dmats (lengths must match)
// Returns an R string built from the result of XGBoosterEvalOneIter.
SEXP XGBoosterEvalOneIter_R(SEXP handle, SEXP iter, SEXP dmats, SEXP evnames) {
  utils::Check(length(dmats) == length(evnames), "dmats and evnams must have same length");
  const int len = length(dmats);
  std::vector<void*> vec_dmats;
  std::vector<std::string> vec_names;
  std::vector<const char*> vec_sptr;
  vec_dmats.reserve(len);
  vec_names.reserve(len);
  vec_sptr.reserve(len);
  for (int i = 0; i < len; ++i) {
    vec_dmats.push_back(R_ExternalPtrAddr(VECTOR_ELT(dmats, i)));
    vec_names.push_back(std::string(CHAR(asChar(VECTOR_ELT(evnames, i)))));
  }
  // Collect the C-string pointers only after vec_names has stopped growing:
  // a reallocation of vec_names moves the strings and would leave pointers
  // taken earlier dangling.
  for (int i = 0; i < len; ++i) {
    vec_sptr.push_back(vec_names[i].c_str());
  }
  // &v[0] is undefined on an empty vector, so hand over NULL when len == 0
  void **dmat_ptr = vec_dmats.empty() ? NULL : &vec_dmats[0];
  const char **name_ptr = vec_sptr.empty() ? NULL : &vec_sptr[0];
  return mkString(XGBoosterEvalOneIter(R_ExternalPtrAddr(handle),
                                       asInteger(iter),
                                       dmat_ptr, name_ptr, len));
}
} }

View File

@ -12,9 +12,37 @@ extern "C" {
extern "C" { extern "C" {
/*! /*!
* \brief load a data matrix * \brief load a data matrix
* \param fname name of the file to load the matrix from
* \return a loaded data matrix * \return a loaded data matrix
*/ */
SEXP XGDMatrixCreateFromFile_R(SEXP fname); SEXP XGDMatrixCreateFromFile_R(SEXP fname);
/*!
* \brief create xgboost learner
* \param dmats a list of dmatrix handles that will be cached
* \return external pointer holding the handle of the created booster
*/
SEXP XGBoosterCreate_R(SEXP dmats);
/*!
* \brief set a single parameter of the booster
* \param handle external pointer holding the booster handle
* \param name parameter name
* \param val value of the parameter, given as a string
*/
void XGBoosterSetParam_R(SEXP handle, SEXP name, SEXP val);
/*!
* \brief update the model in one round using dtrain
* \param handle external pointer holding the booster handle
* \param iter current iteration rounds
* \param dtrain external pointer holding the training data matrix
*/
void XGBoosterUpdateOneIter_R(SEXP handle, SEXP iter, SEXP dtrain);
/*!
* \brief get evaluation statistics for xgboost
* \param handle external pointer holding the booster handle
* \param iter current iteration rounds
* \param dmats list of handles to dmatrices
* \param evnames list of names, one for each entry of dmats
* \return the string containing evaluation statistics
*/
SEXP XGBoosterEvalOneIter_R(SEXP handle, SEXP iter, SEXP dmats, SEXP evnames);
}; };
#endif // XGBOOST_WRAPPER_H_ #endif // XGBOOST_WRAPPER_H_