workable R wrapper
This commit is contained in:
parent
5e23f6577f
commit
40da2fa2c0
5
Makefile
5
Makefile
@ -11,14 +11,15 @@ endif
|
|||||||
# specify tensor path
|
# specify tensor path
|
||||||
BIN = xgboost
|
BIN = xgboost
|
||||||
OBJ =
|
OBJ =
|
||||||
SLIB = python/libxgboostwrapper.so
|
SLIB = python/libxgboostR.so
|
||||||
.PHONY: clean all
|
.PHONY: clean all
|
||||||
|
|
||||||
all: $(BIN) $(OBJ) $(SLIB)
|
all: $(BIN) $(OBJ) $(SLIB)
|
||||||
|
|
||||||
xgboost: src/xgboost_main.cpp src/io/io.cpp src/data.h src/tree/*.h src/tree/*.hpp src/gbm/*.h src/gbm/*.hpp src/utils/*.h src/learner/*.h src/learner/*.hpp
|
xgboost: src/xgboost_main.cpp src/io/io.cpp src/data.h src/tree/*.h src/tree/*.hpp src/gbm/*.h src/gbm/*.hpp src/utils/*.h src/learner/*.h src/learner/*.hpp
|
||||||
# now the wrapper takes in two files. io and wrapper part
|
# now the wrapper takes in two files. io and wrapper part
|
||||||
python/libxgboostwrapper.so: python/xgboost_wrapper.cpp src/io/io.cpp src/*.h src/*/*.hpp src/*/*.h
|
#python/libxgboostwrapper.so: python/xgboost_wrapper.cpp src/io/io.cpp src/*.h src/*/*.hpp src/*/*.h
|
||||||
|
python/libxgboostR.so: python/xgboost_wrapper_R.cpp python/xgboost_wrapper.cpp src/io/io.cpp src/*.h src/*/*.hpp src/*/*.h
|
||||||
|
|
||||||
$(BIN) :
|
$(BIN) :
|
||||||
$(CXX) $(CFLAGS) $(LDFLAGS) -o $@ $(filter %.cpp %.o %.c, $^)
|
$(CXX) $(CFLAGS) $(LDFLAGS) -o $@ $(filter %.cpp %.o %.c, $^)
|
||||||
|
|||||||
@ -1,2 +1,82 @@
|
|||||||
#su
|
# load in library
|
||||||
dyn.load()
|
dyn.load("libxgboostR.so")
|
||||||
|
|
||||||
|
# constructing DMatrix
|
||||||
|
xgb.DMatrix <- function(data) {
|
||||||
|
if (typeof(data) == "character") {
|
||||||
|
handle <- .Call("XGDMatrixCreateFromFile_R", data)
|
||||||
|
} else {
|
||||||
|
stop("xgb.DMatrix cannot recognize data type")
|
||||||
|
}
|
||||||
|
return(structure(handle, class="xgb.DMatrix"))
|
||||||
|
}
|
||||||
|
|
||||||
|
# construct a Booster from cachelist
|
||||||
|
xgb.Booster <- function(cachelist, params) {
|
||||||
|
if (typeof(cachelist) != "list") {
|
||||||
|
stop("xgb.Booster: only accepts list of DMatrix as cachelist")
|
||||||
|
}
|
||||||
|
for (dm in cachelist) {
|
||||||
|
if (class(dm) != "xgb.DMatrix") {
|
||||||
|
stop("xgb.Booster: only accepts list of DMatrix as cachelist")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
handle <- .Call("XGBoosterCreate_R", cachelist)
|
||||||
|
.Call("XGBoosterSetParam_R", handle, "silent", "1")
|
||||||
|
for (i in 1:length(params)) {
|
||||||
|
p = params[i]
|
||||||
|
.Call("XGBoosterSetParam_R", handle, names(p), as.character(p))
|
||||||
|
}
|
||||||
|
return(structure(handle, class="xgb.Booster"))
|
||||||
|
}
|
||||||
|
|
||||||
|
# update booster with dtrain
|
||||||
|
xgb.update <- function(booster, dtrain, iter) {
|
||||||
|
if (class(booster) != "xgb.Booster") {
|
||||||
|
stop("xgb.update: first argument must be type xgb.Booster")
|
||||||
|
}
|
||||||
|
if (class(dtrain) != "xgb.DMatrix") {
|
||||||
|
stop("xgb.update: second argument must be type xgb.DMatrix")
|
||||||
|
}
|
||||||
|
.Call("XGBoosterUpdateOneIter_R", booster, as.integer(iter), dtrain)
|
||||||
|
return(TRUE)
|
||||||
|
}
|
||||||
|
# evaluate one iteration
|
||||||
|
xgb.eval <- function(booster, watchlist, iter) {
|
||||||
|
if (class(booster) != "xgb.Booster") {
|
||||||
|
stop("xgb.eval: first argument must be type xgb.Booster")
|
||||||
|
}
|
||||||
|
if (typeof(watchlist) != "list") {
|
||||||
|
stop("xgb.eval: only accepts list of DMatrix as watchlist")
|
||||||
|
}
|
||||||
|
for (w in watchlist) {
|
||||||
|
if (class(w) != "xgb.DMatrix") {
|
||||||
|
stop("xgb.eval: watch list can only contain xgb.DMatrix")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
evnames <- list()
|
||||||
|
for (i in 1:length(watchlist)) {
|
||||||
|
w <- watchlist[i]
|
||||||
|
if (length(names(w)) == 0) {
|
||||||
|
stop("xgb.eval: name tag must be presented for every elements in watchlist")
|
||||||
|
}
|
||||||
|
evnames <- append(evnames, names(w))
|
||||||
|
}
|
||||||
|
msg <- .Call("XGBoosterEvalOneIter_R", booster, as.integer(iter), watchlist, evnames)
|
||||||
|
return(msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# test code here
|
||||||
|
dtrain <- xgb.DMatrix("example/agaricus.txt.train")
|
||||||
|
dtest <- xgb.DMatrix("example/agaricus.txt.test")
|
||||||
|
param <- list("bst:min_child_weight" = 10,
|
||||||
|
"objective" = "binary:logistic"
|
||||||
|
)
|
||||||
|
bst<- xgb.Booster(list(dtrain, dtest), param )
|
||||||
|
success <- xgb.update(bst, dtrain, 0)
|
||||||
|
watchlist <- list('train'=dtrain,'test'=dtest)
|
||||||
|
cat(xgb.eval(bst, watchlist, 0))
|
||||||
|
cat("\n")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -114,7 +114,7 @@ extern "C" {
|
|||||||
* \param handle handle
|
* \param handle handle
|
||||||
* \param iter current iteration rounds
|
* \param iter current iteration rounds
|
||||||
* \param dtrain training data
|
* \param dtrain training data
|
||||||
*/
|
*/
|
||||||
void XGBoosterUpdateOneIter(void *handle, int iter, void *dtrain);
|
void XGBoosterUpdateOneIter(void *handle, int iter, void *dtrain);
|
||||||
/*!
|
/*!
|
||||||
* \brief update the model, by directly specify gradient and second order gradient,
|
* \brief update the model, by directly specify gradient and second order gradient,
|
||||||
@ -127,7 +127,7 @@ extern "C" {
|
|||||||
*/
|
*/
|
||||||
void XGBoosterBoostOneIter(void *handle, void *dtrain,
|
void XGBoosterBoostOneIter(void *handle, void *dtrain,
|
||||||
float *grad, float *hess, size_t len);
|
float *grad, float *hess, size_t len);
|
||||||
/*!
|
/*!
|
||||||
* \brief get evaluation statistics for xgboost
|
* \brief get evaluation statistics for xgboost
|
||||||
* \param handle handle
|
* \param handle handle
|
||||||
* \param iter current iteration rounds
|
* \param iter current iteration rounds
|
||||||
@ -135,7 +135,7 @@ extern "C" {
|
|||||||
* \param evnames pointers to names of each data
|
* \param evnames pointers to names of each data
|
||||||
* \param len length of dmats
|
* \param len length of dmats
|
||||||
* \return the string containing evaluation stati
|
* \return the string containing evaluation stati
|
||||||
*/
|
*/
|
||||||
const char *XGBoosterEvalOneIter(void *handle, int iter, void *dmats[],
|
const char *XGBoosterEvalOneIter(void *handle, int iter, void *dmats[],
|
||||||
const char *evnames[], size_t len);
|
const char *evnames[], size_t len);
|
||||||
/*!
|
/*!
|
||||||
|
|||||||
@ -1,5 +1,7 @@
|
|||||||
#include "xgboost_wrapper_R.h"
|
#include <vector>
|
||||||
|
#include <string>
|
||||||
#include "xgboost_wrapper.h"
|
#include "xgboost_wrapper.h"
|
||||||
|
#include "xgboost_wrapper_R.h"
|
||||||
#include "../src/utils/utils.h"
|
#include "../src/utils/utils.h"
|
||||||
using namespace xgboost;
|
using namespace xgboost;
|
||||||
|
|
||||||
@ -16,4 +18,48 @@ extern "C" {
|
|||||||
UNPROTECT(1);
|
UNPROTECT(1);
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// functions related to booster
|
||||||
|
void _BoosterFinalizer(SEXP ext) {
|
||||||
|
if (R_ExternalPtrAddr(ext) == NULL) return;
|
||||||
|
XGBoosterFree(R_ExternalPtrAddr(ext));
|
||||||
|
R_ClearExternalPtr(ext);
|
||||||
|
}
|
||||||
|
SEXP XGBoosterCreate_R(SEXP dmats) {
|
||||||
|
int len = length(dmats);
|
||||||
|
std::vector<void*> dvec;
|
||||||
|
for (int i = 0; i < len; ++i){
|
||||||
|
dvec.push_back(R_ExternalPtrAddr(VECTOR_ELT(dmats, i)));
|
||||||
|
}
|
||||||
|
void *handle = XGBoosterCreate(&dvec[0], dvec.size());
|
||||||
|
SEXP ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue));
|
||||||
|
R_RegisterCFinalizerEx(ret, _BoosterFinalizer, TRUE);
|
||||||
|
UNPROTECT(1);
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
void XGBoosterSetParam_R(SEXP handle, SEXP name, SEXP val) {
|
||||||
|
XGBoosterSetParam(R_ExternalPtrAddr(handle),
|
||||||
|
CHAR(asChar(name)),
|
||||||
|
CHAR(asChar(val)));
|
||||||
|
}
|
||||||
|
void XGBoosterUpdateOneIter_R(SEXP handle, SEXP iter, SEXP dtrain) {
|
||||||
|
XGBoosterUpdateOneIter(R_ExternalPtrAddr(handle),
|
||||||
|
asInteger(iter),
|
||||||
|
R_ExternalPtrAddr(dtrain));
|
||||||
|
}
|
||||||
|
SEXP XGBoosterEvalOneIter_R(SEXP handle, SEXP iter, SEXP dmats, SEXP evnames) {
|
||||||
|
utils::Check(length(dmats) == length(evnames), "dmats and evnams must have same length");
|
||||||
|
int len = length(dmats);
|
||||||
|
std::vector<void*> vec_dmats;
|
||||||
|
std::vector<std::string> vec_names;
|
||||||
|
std::vector<const char*> vec_sptr;
|
||||||
|
for (int i = 0; i < len; ++i){
|
||||||
|
vec_dmats.push_back(R_ExternalPtrAddr(VECTOR_ELT(dmats, i)));
|
||||||
|
vec_names.push_back(std::string(CHAR(asChar(VECTOR_ELT(evnames, i)))));
|
||||||
|
vec_sptr.push_back(vec_names.back().c_str());
|
||||||
|
}
|
||||||
|
return mkString(XGBoosterEvalOneIter(R_ExternalPtrAddr(handle),
|
||||||
|
asInteger(iter),
|
||||||
|
&vec_dmats[0], &vec_sptr[0], len));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -12,9 +12,37 @@ extern "C" {
|
|||||||
extern "C" {
|
extern "C" {
|
||||||
/*!
|
/*!
|
||||||
* \brief load a data matrix
|
* \brief load a data matrix
|
||||||
|
* \fname name of the content
|
||||||
* \return a loaded data matrix
|
* \return a loaded data matrix
|
||||||
*/
|
*/
|
||||||
SEXP XGDMatrixCreateFromFile_R(SEXP fname);
|
SEXP XGDMatrixCreateFromFile_R(SEXP fname);
|
||||||
|
/*!
|
||||||
|
* \brief create xgboost learner
|
||||||
|
* \param dmats a list of dmatrix handles that will be cached
|
||||||
|
*/
|
||||||
|
SEXP XGBoosterCreate_R(SEXP dmats);
|
||||||
|
/*!
|
||||||
|
* \brief set parameters
|
||||||
|
* \param handle handle
|
||||||
|
* \param name parameter name
|
||||||
|
* \param val value of parameter
|
||||||
|
*/
|
||||||
|
void XGBoosterSetParam_R(SEXP handle, SEXP name, SEXP val);
|
||||||
|
/*!
|
||||||
|
* \brief update the model in one round using dtrain
|
||||||
|
* \param handle handle
|
||||||
|
* \param iter current iteration rounds
|
||||||
|
* \param dtrain training data
|
||||||
|
*/
|
||||||
|
void XGBoosterUpdateOneIter_R(SEXP ext, SEXP iter, SEXP dtrain);
|
||||||
|
/*!
|
||||||
|
* \brief get evaluation statistics for xgboost
|
||||||
|
* \param handle handle
|
||||||
|
* \param iter current iteration rounds
|
||||||
|
* \param dmats list of handles to dmatrices
|
||||||
|
* \param evname name of evaluation
|
||||||
|
* \return the string containing evaluation stati
|
||||||
|
*/
|
||||||
|
SEXP XGBoosterEvalOneIter_R(SEXP handle, SEXP iter, SEXP dmats, SEXP evnames);
|
||||||
};
|
};
|
||||||
#endif // XGBOOST_WRAPPER_H_
|
#endif // XGBOOST_WRAPPER_H_
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user