workable R wrapper
This commit is contained in:
parent
5e23f6577f
commit
40da2fa2c0
5
Makefile
5
Makefile
@ -11,14 +11,15 @@ endif
|
||||
# specify tensor path
|
||||
BIN = xgboost
|
||||
OBJ =
|
||||
SLIB = python/libxgboostwrapper.so
|
||||
SLIB = python/libxgboostR.so
|
||||
.PHONY: clean all
|
||||
|
||||
all: $(BIN) $(OBJ) $(SLIB)
|
||||
|
||||
xgboost: src/xgboost_main.cpp src/io/io.cpp src/data.h src/tree/*.h src/tree/*.hpp src/gbm/*.h src/gbm/*.hpp src/utils/*.h src/learner/*.h src/learner/*.hpp
|
||||
# now the wrapper takes in two files. io and wrapper part
|
||||
python/libxgboostwrapper.so: python/xgboost_wrapper.cpp src/io/io.cpp src/*.h src/*/*.hpp src/*/*.h
|
||||
#python/libxgboostwrapper.so: python/xgboost_wrapper.cpp src/io/io.cpp src/*.h src/*/*.hpp src/*/*.h
|
||||
python/libxgboostR.so: python/xgboost_wrapper_R.cpp python/xgboost_wrapper.cpp src/io/io.cpp src/*.h src/*/*.hpp src/*/*.h
|
||||
|
||||
$(BIN) :
|
||||
$(CXX) $(CFLAGS) $(LDFLAGS) -o $@ $(filter %.cpp %.o %.c, $^)
|
||||
|
||||
@ -1,2 +1,82 @@
|
||||
#su
|
||||
dyn.load()
|
||||
# load in library
|
||||
dyn.load("libxgboostR.so")
|
||||
|
||||
# constructing DMatrix
|
||||
xgb.DMatrix <- function(data) {
|
||||
if (typeof(data) == "character") {
|
||||
handle <- .Call("XGDMatrixCreateFromFile_R", data)
|
||||
} else {
|
||||
stop("xgb.DMatrix cannot recognize data type")
|
||||
}
|
||||
return(structure(handle, class="xgb.DMatrix"))
|
||||
}
|
||||
|
||||
# construct a Booster from cachelist
|
||||
xgb.Booster <- function(cachelist, params) {
|
||||
if (typeof(cachelist) != "list") {
|
||||
stop("xgb.Booster: only accepts list of DMatrix as cachelist")
|
||||
}
|
||||
for (dm in cachelist) {
|
||||
if (class(dm) != "xgb.DMatrix") {
|
||||
stop("xgb.Booster: only accepts list of DMatrix as cachelist")
|
||||
}
|
||||
}
|
||||
handle <- .Call("XGBoosterCreate_R", cachelist)
|
||||
.Call("XGBoosterSetParam_R", handle, "silent", "1")
|
||||
for (i in 1:length(params)) {
|
||||
p = params[i]
|
||||
.Call("XGBoosterSetParam_R", handle, names(p), as.character(p))
|
||||
}
|
||||
return(structure(handle, class="xgb.Booster"))
|
||||
}
|
||||
|
||||
# update booster with dtrain
|
||||
xgb.update <- function(booster, dtrain, iter) {
|
||||
if (class(booster) != "xgb.Booster") {
|
||||
stop("xgb.update: first argument must be type xgb.Booster")
|
||||
}
|
||||
if (class(dtrain) != "xgb.DMatrix") {
|
||||
stop("xgb.update: second argument must be type xgb.DMatrix")
|
||||
}
|
||||
.Call("XGBoosterUpdateOneIter_R", booster, as.integer(iter), dtrain)
|
||||
return(TRUE)
|
||||
}
|
||||
# evaluate one iteration
|
||||
xgb.eval <- function(booster, watchlist, iter) {
|
||||
if (class(booster) != "xgb.Booster") {
|
||||
stop("xgb.eval: first argument must be type xgb.Booster")
|
||||
}
|
||||
if (typeof(watchlist) != "list") {
|
||||
stop("xgb.eval: only accepts list of DMatrix as watchlist")
|
||||
}
|
||||
for (w in watchlist) {
|
||||
if (class(w) != "xgb.DMatrix") {
|
||||
stop("xgb.eval: watch list can only contain xgb.DMatrix")
|
||||
}
|
||||
}
|
||||
evnames <- list()
|
||||
for (i in 1:length(watchlist)) {
|
||||
w <- watchlist[i]
|
||||
if (length(names(w)) == 0) {
|
||||
stop("xgb.eval: name tag must be presented for every elements in watchlist")
|
||||
}
|
||||
evnames <- append(evnames, names(w))
|
||||
}
|
||||
msg <- .Call("XGBoosterEvalOneIter_R", booster, as.integer(iter), watchlist, evnames)
|
||||
return(msg)
|
||||
}
|
||||
|
||||
|
||||
# test code here
|
||||
dtrain <- xgb.DMatrix("example/agaricus.txt.train")
|
||||
dtest <- xgb.DMatrix("example/agaricus.txt.test")
|
||||
param <- list("bst:min_child_weight" = 10,
|
||||
"objective" = "binary:logistic"
|
||||
)
|
||||
bst<- xgb.Booster(list(dtrain, dtest), param )
|
||||
success <- xgb.update(bst, dtrain, 0)
|
||||
watchlist <- list('train'=dtrain,'test'=dtest)
|
||||
cat(xgb.eval(bst, watchlist, 0))
|
||||
cat("\n")
|
||||
|
||||
|
||||
|
||||
@ -114,7 +114,7 @@ extern "C" {
|
||||
* \param handle handle
|
||||
* \param iter current iteration rounds
|
||||
* \param dtrain training data
|
||||
*/
|
||||
*/
|
||||
void XGBoosterUpdateOneIter(void *handle, int iter, void *dtrain);
|
||||
/*!
|
||||
* \brief update the model, by directly specify gradient and second order gradient,
|
||||
@ -127,7 +127,7 @@ extern "C" {
|
||||
*/
|
||||
void XGBoosterBoostOneIter(void *handle, void *dtrain,
|
||||
float *grad, float *hess, size_t len);
|
||||
/*!
|
||||
/*!
|
||||
* \brief get evaluation statistics for xgboost
|
||||
* \param handle handle
|
||||
* \param iter current iteration rounds
|
||||
@ -135,7 +135,7 @@ extern "C" {
|
||||
* \param evnames pointers to names of each data
|
||||
* \param len length of dmats
|
||||
* \return the string containing evaluation stati
|
||||
*/
|
||||
*/
|
||||
const char *XGBoosterEvalOneIter(void *handle, int iter, void *dmats[],
|
||||
const char *evnames[], size_t len);
|
||||
/*!
|
||||
|
||||
@ -1,5 +1,7 @@
|
||||
#include "xgboost_wrapper_R.h"
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include "xgboost_wrapper.h"
|
||||
#include "xgboost_wrapper_R.h"
|
||||
#include "../src/utils/utils.h"
|
||||
using namespace xgboost;
|
||||
|
||||
@ -16,4 +18,48 @@ extern "C" {
|
||||
UNPROTECT(1);
|
||||
return ret;
|
||||
}
|
||||
|
||||
// functions related to booster
|
||||
void _BoosterFinalizer(SEXP ext) {
|
||||
if (R_ExternalPtrAddr(ext) == NULL) return;
|
||||
XGBoosterFree(R_ExternalPtrAddr(ext));
|
||||
R_ClearExternalPtr(ext);
|
||||
}
|
||||
SEXP XGBoosterCreate_R(SEXP dmats) {
|
||||
int len = length(dmats);
|
||||
std::vector<void*> dvec;
|
||||
for (int i = 0; i < len; ++i){
|
||||
dvec.push_back(R_ExternalPtrAddr(VECTOR_ELT(dmats, i)));
|
||||
}
|
||||
void *handle = XGBoosterCreate(&dvec[0], dvec.size());
|
||||
SEXP ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue));
|
||||
R_RegisterCFinalizerEx(ret, _BoosterFinalizer, TRUE);
|
||||
UNPROTECT(1);
|
||||
return ret;
|
||||
}
|
||||
void XGBoosterSetParam_R(SEXP handle, SEXP name, SEXP val) {
|
||||
XGBoosterSetParam(R_ExternalPtrAddr(handle),
|
||||
CHAR(asChar(name)),
|
||||
CHAR(asChar(val)));
|
||||
}
|
||||
void XGBoosterUpdateOneIter_R(SEXP handle, SEXP iter, SEXP dtrain) {
|
||||
XGBoosterUpdateOneIter(R_ExternalPtrAddr(handle),
|
||||
asInteger(iter),
|
||||
R_ExternalPtrAddr(dtrain));
|
||||
}
|
||||
SEXP XGBoosterEvalOneIter_R(SEXP handle, SEXP iter, SEXP dmats, SEXP evnames) {
|
||||
utils::Check(length(dmats) == length(evnames), "dmats and evnams must have same length");
|
||||
int len = length(dmats);
|
||||
std::vector<void*> vec_dmats;
|
||||
std::vector<std::string> vec_names;
|
||||
std::vector<const char*> vec_sptr;
|
||||
for (int i = 0; i < len; ++i){
|
||||
vec_dmats.push_back(R_ExternalPtrAddr(VECTOR_ELT(dmats, i)));
|
||||
vec_names.push_back(std::string(CHAR(asChar(VECTOR_ELT(evnames, i)))));
|
||||
vec_sptr.push_back(vec_names.back().c_str());
|
||||
}
|
||||
return mkString(XGBoosterEvalOneIter(R_ExternalPtrAddr(handle),
|
||||
asInteger(iter),
|
||||
&vec_dmats[0], &vec_sptr[0], len));
|
||||
}
|
||||
}
|
||||
|
||||
@ -12,9 +12,37 @@ extern "C" {
|
||||
extern "C" {
|
||||
/*!
|
||||
* \brief load a data matrix
|
||||
* \fname name of the content
|
||||
* \return a loaded data matrix
|
||||
*/
|
||||
SEXP XGDMatrixCreateFromFile_R(SEXP fname);
|
||||
|
||||
/*!
|
||||
* \brief create xgboost learner
|
||||
* \param dmats a list of dmatrix handles that will be cached
|
||||
*/
|
||||
SEXP XGBoosterCreate_R(SEXP dmats);
|
||||
/*!
|
||||
* \brief set parameters
|
||||
* \param handle handle
|
||||
* \param name parameter name
|
||||
* \param val value of parameter
|
||||
*/
|
||||
void XGBoosterSetParam_R(SEXP handle, SEXP name, SEXP val);
|
||||
/*!
|
||||
* \brief update the model in one round using dtrain
|
||||
* \param handle handle
|
||||
* \param iter current iteration rounds
|
||||
* \param dtrain training data
|
||||
*/
|
||||
void XGBoosterUpdateOneIter_R(SEXP ext, SEXP iter, SEXP dtrain);
|
||||
/*!
|
||||
* \brief get evaluation statistics for xgboost
|
||||
* \param handle handle
|
||||
* \param iter current iteration rounds
|
||||
* \param dmats list of handles to dmatrices
|
||||
* \param evname name of evaluation
|
||||
* \return the string containing evaluation stati
|
||||
*/
|
||||
SEXP XGBoosterEvalOneIter_R(SEXP handle, SEXP iter, SEXP dmats, SEXP evnames);
|
||||
};
|
||||
#endif // XGBOOST_WRAPPER_H_
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user