Merge pull request #34 from tqchen/unity

Unity
This commit is contained in:
Tianqi Chen 2014-08-23 18:56:38 -07:00
commit b2b5895634
23 changed files with 8977 additions and 27 deletions

View File

@ -1,6 +1,9 @@
export CC = gcc
export CXX = g++
export LDFLAGS= -pthread -lm
# note for R module
# add include path to Rinternals.h here
export CPLUS_INCLUDE_PATH=/usr/share/R/include
ifeq ($(no_omp),1)
export CFLAGS = -Wall -O3 -msse2 -Wno-unknown-pragmas -DDISABLE_OPENMP
@ -11,14 +14,16 @@ endif
# specify tensor path
BIN = xgboost
OBJ =
SLIB = python/libxgboostwrapper.so
.PHONY: clean all
SLIB = wrapper/libxgboostwrapper.so wrapper/libxgboostR.so
.PHONY: clean all R
all: $(BIN) $(OBJ) $(SLIB)
all: $(BIN) wrapper/libxgboostwrapper.so
R: wrapper/libxgboostR.so
xgboost: src/xgboost_main.cpp src/io/io.cpp src/data.h src/tree/*.h src/tree/*.hpp src/gbm/*.h src/gbm/*.hpp src/utils/*.h src/learner/*.h src/learner/*.hpp
# now the wrapper takes in two files. io and wrapper part
python/libxgboostwrapper.so: python/xgboost_wrapper.cpp src/io/io.cpp src/*.h src/*/*.hpp src/*/*.h
wrapper/libxgboostwrapper.so: wrapper/xgboost_wrapper.cpp src/io/io.cpp src/*.h src/*/*.hpp src/*/*.h
wrapper/libxgboostR.so: wrapper/xgboost_wrapper.cpp wrapper/xgboost_R.cpp src/io/io.cpp src/*.h src/*/*.hpp src/*/*.h
$(BIN) :
$(CXX) $(CFLAGS) $(LDFLAGS) -o $@ $(filter %.cpp %.o %.c, $^)

View File

@ -6,7 +6,7 @@ import sys
import numpy as np
# add path of xgboost python module
code_path = os.path.join(
os.path.split(inspect.getfile(inspect.currentframe()))[0], "../../python")
os.path.split(inspect.getfile(inspect.currentframe()))[0], "../../wrapper")
sys.path.append(code_path)

View File

@ -3,7 +3,7 @@
import sys
import numpy as np
# add path of xgboost python module
sys.path.append('../../python/')
sys.path.append('../../wrapper/')
import xgboost as xgb
# path to where the data lies

View File

@ -3,7 +3,7 @@
import sys
import numpy as np
# add path of xgboost python module
sys.path.append('../../python/')
sys.path.append('../../wrapper/')
import xgboost as xgb
from sklearn.ensemble import GradientBoostingClassifier
import time

View File

@ -1,7 +1,7 @@
#! /usr/bin/python
import sys
import numpy as np
sys.path.append('../../python/')
sys.path.append('../../wrapper/')
import xgboost as xgb
# label need to be 0 to num_class -1

View File

@ -1,9 +0,0 @@
python wrapper for xgboost using ctypes
see example for usage
to make the python module, type make in the root directory of project
Graphlab-Create Version
=====
Graphlab Create

View File

@ -43,6 +43,7 @@ inline IEvaluator* CreateEvaluator(const char *name) {
if (!strncmp(name, "ams@", 4)) return new EvalAMS(name);
if (!strncmp(name, "pre@", 4)) return new EvalPrecision(name);
if (!strncmp(name, "pratio@", 7)) return new EvalPrecisionRatio(name);
if (!strncmp(name, "apratio@", 8)) return new EvalPrecisionRatio(name);
if (!strncmp(name, "map", 3)) return new EvalMAP(name);
if (!strncmp(name, "ndcg", 3)) return new EvalNDCG(name);
utils::Error("unknown evaluation metric type: %s", name);

125
wrapper/R-example/demo.R Normal file
View File

@ -0,0 +1,125 @@
# include xgboost library, must set chdir=TRURE
source("../xgboost.R", chdir=TRUE)
# helper function to read libsvm format
# this is very badly written, load in dense, and convert to sparse
# use this only for demo purpose
read.libsvm <- function(fname, maxcol) {
content <- readLines(fname)
nline <- length(content)
label <- numeric(nline)
mat <- matrix(0, nline, maxcol+1)
for (i in 1:nline) {
arr <- as.vector(strsplit(content[i], " ")[[1]])
label[i] <- as.numeric(arr[[1]])
for (j in 2:length(arr)) {
kv <- strsplit(arr[j], ":")[[1]]
# to avoid 0 index
findex <- as.integer(kv[1]) + 1
fvalue <- as.numeric(kv[2])
mat[i,findex] <- fvalue
}
}
mat <- as(mat, "sparseMatrix")
return(list(label=label, data=mat))
}
# test code here
dtrain <- xgb.DMatrix("agaricus.txt.train")
dtest <- xgb.DMatrix("agaricus.txt.test")
param = list("bst:max_depth"=2, "bst:eta"=1, "silent"=1, "objective"="binary:logistic")
watchlist <- list("eval"=dtest,"train"=dtrain)
# training xgboost model
bst <- xgb.train(param, dtrain, nround=2, watchlist=watchlist)
# make prediction
preds <- xgb.predict(bst, dtest)
labels <- xgb.getinfo(dtest, "label")
err <- as.real(sum(as.integer(preds > 0.5) != labels)) / length(labels)
# print error rate
print(paste("error=",err))
# dump model
xgb.dump(bst, "dump.raw.txt")
# dump model with feature map
xgb.dump(bst, "dump.nice.txt", "featmap.txt")
# save dmatrix into binary buffer
succ <- xgb.save(dtest, "dtest.buffer")
# save model into file
succ <- xgb.save(bst, "xgb.model")
# load model and data in
bst2 <- xgb.Booster(modelfile="xgb.model")
dtest2 <- xgb.DMatrix("dtest.buffer")
preds2 <- xgb.predict(bst2, dtest2)
# assert they are the same
stopifnot(sum(abs(preds2-preds)) == 0)
###
# build dmatrix from sparseMatrix
###
print ('start running example of build DMatrix from R.sparseMatrix')
csc <- read.libsvm("agaricus.txt.train", 126)
label <- csc$label
data <- csc$data
dtrain <- xgb.DMatrix(data, info=list(label=label) )
watchlist <- list("eval"=dtest,"train"=dtrain)
bst <- xgb.train(param, dtrain, nround=2, watchlist=watchlist)
###
# build dmatrix from dense matrix
###
print ('start running example of build DMatrix from R.Matrix')
mat = as.matrix(data)
dtrain <- xgb.DMatrix(mat, info=list(label=label) )
watchlist <- list("eval"=dtest,"train"=dtrain)
bst <- xgb.train(param, dtrain, nround=2, watchlist=watchlist)
###
# advanced: cutomsized loss function
#
print("start running example to used cutomized objective function")
# note: for customized objective function, we leave objective as default
# note: what we are getting is margin value in prediction
# you must know what you are doing
param <- list("bst:max_depth" = 2, "bst:eta" = 1, "silent" =1)
# user define objective function, given prediction, return gradient and second order gradient
# this is loglikelihood loss
logregobj <- function(preds, dtrain) {
labels <- xgb.getinfo(dtrain, "label")
preds <- 1.0 / (1.0 + exp(-preds))
grad <- preds - labels
hess <- preds * (1.0-preds)
return(list(grad=grad, hess=hess))
}
# user defined evaluation function, return a list(metric="metric-name", value="metric-value")
# NOTE: when you do customized loss function, the default prediction value is margin
# this may make buildin evalution metric not function properly
# for example, we are doing logistic loss, the prediction is score before logistic transformation
# the buildin evaluation error assumes input is after logistic transformation
# Take this in mind when you use the customization, and maybe you need write customized evaluation function
evalerror <- function(preds, dtrain) {
labels <- xgb.getinfo(dtrain, "label")
err <- as.real(sum(labels != (preds > 0.0))) / length(labels)
return(list(metric="error", value=err))
}
# training with customized objective, we can also do step by step training
# simply look at xgboost.py"s implementation of train
bst <- xgb.train(param, dtrain, nround=2, watchlist, logregobj, evalerror)
###
# advanced: start from a initial base prediction
#
print ("start running example to start from a initial prediction")
# specify parameters via map, definition are same as c++ version
param = list("bst:max_depth"=2, "bst:eta"=1, "silent"=1, "objective"="binary:logistic")
# train xgboost for 1 round
bst <- xgb.train( param, dtrain, 1, watchlist )
# Note: we need the margin value instead of transformed prediction in set_base_margin
# do predict with output_margin=True, will always give you margin values before logistic transformation
ptrain <- xgb.predict(bst, dtrain, outputmargin=TRUE)
ptest <- xgb.predict(bst, dtest, outputmargin=TRUE)
succ <- xgb.setinfo(dtrain, "base_margin", ptrain)
succ <- xgb.setinfo(dtest, "base_margin", ptest)
print ("this is result of running from initial prediction")
bst <- xgb.train( param, dtrain, 1, watchlist )

15
wrapper/README.md Normal file
View File

@ -0,0 +1,15 @@
Wrapper of XGBoost
=====
This folder provides wrapper of xgboost to other languages
Python
=====
* To make the python module, type ```make``` in the root directory of project
* Refer to the walk through example in [python-example/demo.py](python-example/demo.py)
R
=====
* To make the R wrapper, type ```make R``` in the root directory of project
* R module need Rinternals.h, find the path in your system and add it to CPLUS_INCLUDE_PATH in Makefile
* Refer to the walk through example in [R-example/demo.R](R-example/demo.R)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -30,6 +30,16 @@ bst.dump_model('dump.raw.txt')
# dump model with feature map
bst.dump_model('dump.nice.txt','featmap.txt')
# save dmatrix into binary buffer
dtest.save_binary('dtest.buffer')
bst.save_model('xgb.model')
# load model and data in
bst2 = xgb.Booster(model_file='xgb.model')
dtest2 = xgb.DMatrix('dtest.buffer')
preds2 = bst2.predict(dtest2)
# assert they are the same
assert np.sum(np.abs(preds2-preds)) == 0
###
# build dmatrix from scipy.sparse
print ('start running example of build DMatrix from scipy.sparse')
@ -58,7 +68,7 @@ evallist = [(dtest,'eval'), (dtrain,'train')]
bst = xgb.train( param, dtrain, num_round, evallist )
###
# advanced: cutomsized loss function, set loss_type to 0, so that predict get untransformed score
# advanced: cutomsized loss function
#
print ('start running example to used cutomized objective function')
@ -92,7 +102,6 @@ def evalerror(preds, dtrain):
# simply look at xgboost.py's implementation of train
bst = xgb.train(param, dtrain, num_round, evallist, logregobj, evalerror)
###
# advanced: start from a initial base prediction
#

View File

@ -0,0 +1,126 @@
0 cap-shape=bell i
1 cap-shape=conical i
2 cap-shape=convex i
3 cap-shape=flat i
4 cap-shape=knobbed i
5 cap-shape=sunken i
6 cap-surface=fibrous i
7 cap-surface=grooves i
8 cap-surface=scaly i
9 cap-surface=smooth i
10 cap-color=brown i
11 cap-color=buff i
12 cap-color=cinnamon i
13 cap-color=gray i
14 cap-color=green i
15 cap-color=pink i
16 cap-color=purple i
17 cap-color=red i
18 cap-color=white i
19 cap-color=yellow i
20 bruises?=bruises i
21 bruises?=no i
22 odor=almond i
23 odor=anise i
24 odor=creosote i
25 odor=fishy i
26 odor=foul i
27 odor=musty i
28 odor=none i
29 odor=pungent i
30 odor=spicy i
31 gill-attachment=attached i
32 gill-attachment=descending i
33 gill-attachment=free i
34 gill-attachment=notched i
35 gill-spacing=close i
36 gill-spacing=crowded i
37 gill-spacing=distant i
38 gill-size=broad i
39 gill-size=narrow i
40 gill-color=black i
41 gill-color=brown i
42 gill-color=buff i
43 gill-color=chocolate i
44 gill-color=gray i
45 gill-color=green i
46 gill-color=orange i
47 gill-color=pink i
48 gill-color=purple i
49 gill-color=red i
50 gill-color=white i
51 gill-color=yellow i
52 stalk-shape=enlarging i
53 stalk-shape=tapering i
54 stalk-root=bulbous i
55 stalk-root=club i
56 stalk-root=cup i
57 stalk-root=equal i
58 stalk-root=rhizomorphs i
59 stalk-root=rooted i
60 stalk-root=missing i
61 stalk-surface-above-ring=fibrous i
62 stalk-surface-above-ring=scaly i
63 stalk-surface-above-ring=silky i
64 stalk-surface-above-ring=smooth i
65 stalk-surface-below-ring=fibrous i
66 stalk-surface-below-ring=scaly i
67 stalk-surface-below-ring=silky i
68 stalk-surface-below-ring=smooth i
69 stalk-color-above-ring=brown i
70 stalk-color-above-ring=buff i
71 stalk-color-above-ring=cinnamon i
72 stalk-color-above-ring=gray i
73 stalk-color-above-ring=orange i
74 stalk-color-above-ring=pink i
75 stalk-color-above-ring=red i
76 stalk-color-above-ring=white i
77 stalk-color-above-ring=yellow i
78 stalk-color-below-ring=brown i
79 stalk-color-below-ring=buff i
80 stalk-color-below-ring=cinnamon i
81 stalk-color-below-ring=gray i
82 stalk-color-below-ring=orange i
83 stalk-color-below-ring=pink i
84 stalk-color-below-ring=red i
85 stalk-color-below-ring=white i
86 stalk-color-below-ring=yellow i
87 veil-type=partial i
88 veil-type=universal i
89 veil-color=brown i
90 veil-color=orange i
91 veil-color=white i
92 veil-color=yellow i
93 ring-number=none i
94 ring-number=one i
95 ring-number=two i
96 ring-type=cobwebby i
97 ring-type=evanescent i
98 ring-type=flaring i
99 ring-type=large i
100 ring-type=none i
101 ring-type=pendant i
102 ring-type=sheathing i
103 ring-type=zone i
104 spore-print-color=black i
105 spore-print-color=brown i
106 spore-print-color=buff i
107 spore-print-color=chocolate i
108 spore-print-color=green i
109 spore-print-color=orange i
110 spore-print-color=purple i
111 spore-print-color=white i
112 spore-print-color=yellow i
113 population=abundant i
114 population=clustered i
115 population=numerous i
116 population=scattered i
117 population=several i
118 population=solitary i
119 habitat=grasses i
120 habitat=leaves i
121 habitat=meadows i
122 habitat=paths i
123 habitat=urban i
124 habitat=waste i
125 habitat=woods i

222
wrapper/xgboost.R Normal file
View File

@ -0,0 +1,222 @@
# depends on matrix
succ <- require("Matrix")
if (!succ) {
stop("xgboost depends on Matrix library")
}
# load in library
dyn.load("./libxgboostR.so")
# constructing DMatrix
xgb.DMatrix <- function(data, info=list(), missing=0.0) {
if (typeof(data) == "character") {
handle <- .Call("XGDMatrixCreateFromFile_R", data, as.integer(FALSE))
} else if(is.matrix(data)) {
handle <- .Call("XGDMatrixCreateFromMat_R", data, missing)
} else if(class(data) == "dgCMatrix") {
handle <- .Call("XGDMatrixCreateFromCSC_R", data@p, data@i, data@x)
} else {
stop(paste("xgb.DMatrix: does not support to construct from ", typeof(data)))
}
dmat <- structure(handle, class="xgb.DMatrix")
if (length(info) != 0) {
for (i in 1:length(info)) {
p <- info[i]
xgb.setinfo(dmat, names(p), p[[1]])
}
}
return(dmat)
}
# get information from dmatrix
xgb.getinfo <- function(dmat, name) {
if (typeof(name) != "character") {
stop("xgb.getinfo: name must be character")
}
if (class(dmat) != "xgb.DMatrix") {
stop("xgb.setinfo: first argument dtrain must be xgb.DMatrix");
}
if (name != "label" &&
name != "weight" &&
name != "base_margin" ) {
stop(paste("xgb.getinfo: unknown info name", name))
}
ret <- .Call("XGDMatrixGetInfo_R", dmat, name)
return(ret)
}
# set information into dmatrix, this mutate dmatrix
xgb.setinfo <- function(dmat, name, info) {
if (class(dmat) != "xgb.DMatrix") {
stop("xgb.setinfo: first argument dtrain must be xgb.DMatrix");
}
if (name == "label") {
.Call("XGDMatrixSetInfo_R", dmat, name, as.real(info))
return(TRUE)
}
if (name == "weight") {
.Call("XGDMatrixSetInfo_R", dmat, name, as.real(info))
return(TRUE)
}
if (name == "base_margin") {
.Call("XGDMatrixSetInfo_R", dmat, name, as.real(info))
return(TRUE)
}
if (name == "group") {
.Call("XGDMatrixSetInfo_R", dmat, name, as.integer(info))
return(TRUE)
}
stop(pase("xgb.setinfo: unknown info name", name))
return(FALSE)
}
# construct a Booster from cachelist
xgb.Booster <- function(params = list(), cachelist = list(), modelfile = NULL) {
if (typeof(cachelist) != "list") {
stop("xgb.Booster: only accepts list of DMatrix as cachelist")
}
for (dm in cachelist) {
if (class(dm) != "xgb.DMatrix") {
stop("xgb.Booster: only accepts list of DMatrix as cachelist")
}
}
handle <- .Call("XGBoosterCreate_R", cachelist)
.Call("XGBoosterSetParam_R", handle, "seed", "0")
if (length(params) != 0) {
for (i in 1:length(params)) {
p <- params[i]
.Call("XGBoosterSetParam_R", handle, names(p), as.character(p))
}
}
if (!is.null(modelfile)) {
if (typeof(modelfile) != "character"){
stop("xgb.Booster: modelfile must be character");
}
.Call("XGBoosterLoadModel_R", handle, modelfile)
}
return(structure(handle, class="xgb.Booster"))
}
# train a model using given parameters
xgb.train <- function(params, dtrain, nrounds=10, watchlist=list(), obj=NULL, feval=NULL) {
if (typeof(params) != "list") {
stop("xgb.train: first argument params must be list");
}
if (class(dtrain) != "xgb.DMatrix") {
stop("xgb.train: second argument dtrain must be xgb.DMatrix");
}
bst <- xgb.Booster(params, append(watchlist,dtrain))
for (i in 1:nrounds) {
if (is.null(obj)) {
succ <- xgb.iter.update(bst, dtrain, i-1)
} else {
pred <- xgb.predict(bst, dtrain)
gpair <- obj(pred, dtrain)
succ <- xgb.iter.boost(bst, dtrain, gpair)
}
if (length(watchlist) != 0) {
if (is.null(feval)) {
msg <- xgb.iter.eval(bst, watchlist, i-1)
cat(msg); cat("\n")
} else {
cat("["); cat(i); cat("]");
for (j in 1:length(watchlist)) {
w <- watchlist[j]
if (length(names(w)) == 0) {
stop("xgb.eval: name tag must be presented for every elements in watchlist")
}
ret <- feval(xgb.predict(bst, w[[1]]), w[[1]])
cat("\t"); cat(names(w)); cat("-"); cat(ret$metric);
cat(":"); cat(ret$value)
}
cat("\n")
}
}
}
return(bst)
}
# save model or DMatrix to file
xgb.save <- function(handle, fname) {
if (typeof(fname) != "character") {
stop("xgb.save: fname must be character");
}
if (class(handle) == "xgb.Booster") {
.Call("XGBoosterSaveModel_R", handle, fname);
return(TRUE)
}
if (class(handle) == "xgb.DMatrix") {
.Call("XGDMatrixSaveBinary_R", handle, fname, as.integer(FALSE))
return(TRUE)
}
stop("xgb.save: the input must be either xgb.DMatrix or xgb.Booster")
return(FALSE)
}
# predict
xgb.predict <- function(booster, dmat, outputmargin = FALSE) {
if (class(booster) != "xgb.Booster") {
stop("xgb.predict: first argument must be type xgb.Booster")
}
if (class(dmat) != "xgb.DMatrix") {
stop("xgb.predict: second argument must be type xgb.DMatrix")
}
ret <- .Call("XGBoosterPredict_R", booster, dmat, as.integer(outputmargin))
return(ret)
}
# dump model
xgb.dump <- function(booster, fname, fmap = "") {
if (class(booster) != "xgb.Booster") {
stop("xgb.dump: first argument must be type xgb.Booster")
}
if (typeof(fname) != "character"){
stop("xgb.dump: second argument must be type character")
}
.Call("XGBoosterDumpModel_R", booster, fname, fmap)
return(TRUE)
}
##--------------------------------------
# the following are low level iteratively function, not needed
# if you do not want to use them
#---------------------------------------
# iteratively update booster with dtrain
xgb.iter.update <- function(booster, dtrain, iter) {
if (class(booster) != "xgb.Booster") {
stop("xgb.iter.update: first argument must be type xgb.Booster")
}
if (class(dtrain) != "xgb.DMatrix") {
stop("xgb.iter.update: second argument must be type xgb.DMatrix")
}
.Call("XGBoosterUpdateOneIter_R", booster, as.integer(iter), dtrain)
return(TRUE)
}
# iteratively update booster with customized statistics
xgb.iter.boost <- function(booster, dtrain, gpair) {
if (class(booster) != "xgb.Booster") {
stop("xgb.iter.update: first argument must be type xgb.Booster")
}
if (class(dtrain) != "xgb.DMatrix") {
stop("xgb.iter.update: second argument must be type xgb.DMatrix")
}
.Call("XGBoosterBoostOneIter_R", booster, dtrain, gpair$grad, gpair$hess)
return(TRUE)
}
# iteratively evaluate one iteration
xgb.iter.eval <- function(booster, watchlist, iter) {
if (class(booster) != "xgb.Booster") {
stop("xgb.eval: first argument must be type xgb.Booster")
}
if (typeof(watchlist) != "list") {
stop("xgb.eval: only accepts list of DMatrix as watchlist")
}
for (w in watchlist) {
if (class(w) != "xgb.DMatrix") {
stop("xgb.eval: watch list can only contain xgb.DMatrix")
}
}
evnames <- list()
if (length(watchlist) != 0) {
for (i in 1:length(watchlist)) {
w <- watchlist[i]
if (length(names(w)) == 0) {
stop("xgb.eval: name tag must be presented for every elements in watchlist")
}
evnames <- append(evnames, names(w))
}
}
msg <- .Call("XGBoosterEvalOneIter_R", booster, as.integer(iter), watchlist, evnames)
return(msg)
}

View File

@ -127,7 +127,7 @@ class DMatrix:
class Booster:
"""learner class """
def __init__(self, params={}, cache=[]):
def __init__(self, params={}, cache=[], model_file = None):
""" constructor, param: """
for d in cache:
assert isinstance(d, DMatrix)
@ -135,6 +135,8 @@ class Booster:
self.handle = ctypes.c_void_p(xglib.XGBoosterCreate(dmats, len(cache)))
self.set_param({'seed':0})
self.set_param(params)
if model_file != None:
self.load_model(model_file)
def __del__(self):
xglib.XGBoosterFree(self.handle)
def set_param(self, params, pv=None):

208
wrapper/xgboost_R.cpp Normal file
View File

@ -0,0 +1,208 @@
#include <vector>
#include <string>
#include <utility>
#include <cstring>
#include "xgboost_wrapper.h"
#include "xgboost_R.h"
#include "../src/utils/utils.h"
#include "../src/utils/omp.h"
#include "../src/utils/matrix_csr.h"
using namespace xgboost;
extern "C" {
void _DMatrixFinalizer(SEXP ext) {
if (R_ExternalPtrAddr(ext) == NULL) return;
XGDMatrixFree(R_ExternalPtrAddr(ext));
R_ClearExternalPtr(ext);
}
SEXP XGDMatrixCreateFromFile_R(SEXP fname, SEXP silent) {
void *handle = XGDMatrixCreateFromFile(CHAR(asChar(fname)), asInteger(silent));
SEXP ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue));
R_RegisterCFinalizerEx(ret, _DMatrixFinalizer, TRUE);
UNPROTECT(1);
return ret;
}
SEXP XGDMatrixCreateFromMat_R(SEXP mat,
SEXP missing) {
SEXP dim = getAttrib(mat, R_DimSymbol);
int nrow = INTEGER(dim)[0];
int ncol = INTEGER(dim)[1];
double *din = REAL(mat);
std::vector<float> data(nrow * ncol);
#pragma omp parallel for schedule(static)
for (int i = 0; i < nrow; ++i) {
for (int j = 0; j < ncol; ++j) {
data[i * ncol +j] = din[i + nrow * j];
}
}
void *handle = XGDMatrixCreateFromMat(&data[0], nrow, ncol, asReal(missing));
SEXP ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue));
R_RegisterCFinalizerEx(ret, _DMatrixFinalizer, TRUE);
UNPROTECT(1);
return ret;
}
SEXP XGDMatrixCreateFromCSC_R(SEXP indptr,
SEXP indices,
SEXP data) {
const int *col_ptr = INTEGER(indptr);
const int *row_index = INTEGER(indices);
const double *col_data = REAL(data);
int ncol = length(indptr) - 1;
int ndata = length(data);
// transform into CSR format
std::vector<size_t> row_ptr;
std::vector< std::pair<unsigned, float> > csr_data;
utils::SparseCSRMBuilder< std::pair<unsigned,float> > builder(row_ptr, csr_data);
builder.InitBudget();
for (int i = 0; i < ncol; ++i) {
for (int j = col_ptr[i]; j < col_ptr[i+1]; ++j) {
builder.AddBudget(row_index[j]);
}
}
builder.InitStorage();
for (int i = 0; i < ncol; ++i) {
for (int j = col_ptr[i]; j < col_ptr[i+1]; ++j) {
builder.PushElem(row_index[j], std::make_pair(i, col_data[j]));
}
}
utils::Assert(csr_data.size() == static_cast<size_t>(ndata), "BUG CreateFromCSC");
std::vector<float> row_data(ndata);
std::vector<unsigned> col_index(ndata);
#pragma omp parallel for schedule(static)
for (int i = 0; i < ndata; ++i) {
col_index[i] = csr_data[i].first;
row_data[i] = csr_data[i].second;
}
void *handle = XGDMatrixCreateFromCSR(&row_ptr[0], &col_index[0], &row_data[0], row_ptr.size(), ndata );
SEXP ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue));
R_RegisterCFinalizerEx(ret, _DMatrixFinalizer, TRUE);
UNPROTECT(1);
return ret;
}
void XGDMatrixSaveBinary_R(SEXP handle, SEXP fname, SEXP silent) {
XGDMatrixSaveBinary(R_ExternalPtrAddr(handle),
CHAR(asChar(fname)), asInteger(silent));
}
void XGDMatrixSetInfo_R(SEXP handle, SEXP field, SEXP array) {
int len = length(array);
const char *name = CHAR(asChar(field));
if (!strcmp("group", name)) {
std::vector<unsigned> vec(len);
#pragma omp parallel for schedule(static)
for (int i = 0; i < len; ++i) {
vec[i] = static_cast<unsigned>(INTEGER(array)[i]);
}
XGDMatrixSetGroup(R_ExternalPtrAddr(handle), &vec[0], len);
return;
}
{
std::vector<float> vec(len);
#pragma omp parallel for schedule(static)
for (int i = 0; i < len; ++i) {
vec[i] = REAL(array)[i];
}
XGDMatrixSetFloatInfo(R_ExternalPtrAddr(handle),
CHAR(asChar(field)),
&vec[0], len);
}
}
SEXP XGDMatrixGetInfo_R(SEXP handle, SEXP field) {
size_t olen;
const float *res = XGDMatrixGetFloatInfo(R_ExternalPtrAddr(handle),
CHAR(asChar(field)), &olen);
SEXP ret = PROTECT(allocVector(REALSXP, olen));
for (size_t i = 0; i < olen; ++i) {
REAL(ret)[i] = res[i];
}
UNPROTECT(1);
return ret;
}
// functions related to booster
void _BoosterFinalizer(SEXP ext) {
if (R_ExternalPtrAddr(ext) == NULL) return;
XGBoosterFree(R_ExternalPtrAddr(ext));
R_ClearExternalPtr(ext);
}
SEXP XGBoosterCreate_R(SEXP dmats) {
int len = length(dmats);
std::vector<void*> dvec;
for (int i = 0; i < len; ++i){
dvec.push_back(R_ExternalPtrAddr(VECTOR_ELT(dmats, i)));
}
void *handle = XGBoosterCreate(&dvec[0], dvec.size());
SEXP ret = PROTECT(R_MakeExternalPtr(handle, R_NilValue, R_NilValue));
R_RegisterCFinalizerEx(ret, _BoosterFinalizer, TRUE);
UNPROTECT(1);
return ret;
}
void XGBoosterSetParam_R(SEXP handle, SEXP name, SEXP val) {
XGBoosterSetParam(R_ExternalPtrAddr(handle),
CHAR(asChar(name)),
CHAR(asChar(val)));
}
void XGBoosterUpdateOneIter_R(SEXP handle, SEXP iter, SEXP dtrain) {
XGBoosterUpdateOneIter(R_ExternalPtrAddr(handle),
asInteger(iter),
R_ExternalPtrAddr(dtrain));
}
void XGBoosterBoostOneIter_R(SEXP handle, SEXP dtrain, SEXP grad, SEXP hess) {
utils::Check(length(grad) == length(hess), "gradient and hess must have same length");
int len = length(grad);
std::vector<float> tgrad(len), thess(len);
#pragma omp parallel for schedule(static)
for (int j = 0; j < len; ++j) {
tgrad[j] = REAL(grad)[j];
thess[j] = REAL(hess)[j];
}
XGBoosterBoostOneIter(R_ExternalPtrAddr(handle),
R_ExternalPtrAddr(dtrain),
&tgrad[0], &thess[0], len);
}
SEXP XGBoosterEvalOneIter_R(SEXP handle, SEXP iter, SEXP dmats, SEXP evnames) {
utils::Check(length(dmats) == length(evnames), "dmats and evnams must have same length");
int len = length(dmats);
std::vector<void*> vec_dmats;
std::vector<std::string> vec_names;
std::vector<const char*> vec_sptr;
for (int i = 0; i < len; ++i){
vec_dmats.push_back(R_ExternalPtrAddr(VECTOR_ELT(dmats, i)));
vec_names.push_back(std::string(CHAR(asChar(VECTOR_ELT(evnames, i)))));
vec_sptr.push_back(vec_names.back().c_str());
}
return mkString(XGBoosterEvalOneIter(R_ExternalPtrAddr(handle),
asInteger(iter),
&vec_dmats[0], &vec_sptr[0], len));
}
SEXP XGBoosterPredict_R(SEXP handle, SEXP dmat, SEXP output_margin) {
size_t olen;
const float *res = XGBoosterPredict(R_ExternalPtrAddr(handle),
R_ExternalPtrAddr(dmat),
asInteger(output_margin),
&olen);
SEXP ret = PROTECT(allocVector(REALSXP, olen));
for (size_t i = 0; i < olen; ++i) {
REAL(ret)[i] = res[i];
}
UNPROTECT(1);
return ret;
}
void XGBoosterLoadModel_R(SEXP handle, SEXP fname) {
XGBoosterLoadModel(R_ExternalPtrAddr(handle), CHAR(asChar(fname)));
}
void XGBoosterSaveModel_R(SEXP handle, SEXP fname) {
XGBoosterSaveModel(R_ExternalPtrAddr(handle), CHAR(asChar(fname)));
}
void XGBoosterDumpModel_R(SEXP handle, SEXP fname, SEXP fmap) {
size_t olen;
const char **res = XGBoosterDumpModel(R_ExternalPtrAddr(handle),
CHAR(asChar(fmap)),
&olen);
FILE *fo = utils::FopenCheck(CHAR(asChar(fname)), "w");
for (size_t i = 0; i < olen; ++i) {
fprintf(fo, "booster[%lu]:\n", i);
fprintf(fo, "%s", res[i]);
}
fclose(fo);
}
}

124
wrapper/xgboost_R.h Normal file
View File

@ -0,0 +1,124 @@
#ifndef XGBOOST_WRAPPER_R_H_
#define XGBOOST_WRAPPER_R_H_
/*!
* \file xgboost_wrapper_R.h
* \author Tianqi Chen
* \brief R wrapper of xgboost
*/
extern "C" {
#include <Rinternals.h>
}
extern "C" {
/*!
* \brief load a data matrix
* \param fname name of the content
* \param silent whether print messages
* \return a loaded data matrix
*/
SEXP XGDMatrixCreateFromFile_R(SEXP fname, SEXP silent);
/*!
* \brief create matrix content from dense matrix
* This assumes the matrix is stored in column major format
* \param data R Matrix object
* \param missing which value to represent missing value
* \return created dmatrix
*/
SEXP XGDMatrixCreateFromMat_R(SEXP mat,
SEXP missing);
/*!
* \brief create a matrix content from CSC format
* \param indptr pointer to column headers
* \param indices row indices
* \param data content of the data
* \return created dmatrix
*/
SEXP XGDMatrixCreateFromCSC_R(SEXP indptr,
SEXP indices,
SEXP data);
/*!
* \brief load a data matrix into binary file
* \param handle a instance of data matrix
* \param fname file name
* \param silent print statistics when saving
*/
void XGDMatrixSaveBinary_R(SEXP handle, SEXP fname, SEXP silent);
/*!
* \brief set information to dmatrix
* \param handle a instance of data matrix
* \param field field name, can be label, weight
* \param array pointer to float vector
*/
void XGDMatrixSetInfo_R(SEXP handle, SEXP field, SEXP array);
/*!
* \brief get info vector from matrix
* \param handle a instance of data matrix
* \param field field name
* \return info vector
*/
SEXP XGDMatrixGetInfo_R(SEXP handle, SEXP field);
/*!
* \brief create xgboost learner
* \param dmats a list of dmatrix handles that will be cached
*/
SEXP XGBoosterCreate_R(SEXP dmats);
/*!
* \brief set parameters
* \param handle handle
* \param name parameter name
* \param val value of parameter
*/
void XGBoosterSetParam_R(SEXP handle, SEXP name, SEXP val);
/*!
* \brief update the model in one round using dtrain
* \param handle handle
* \param iter current iteration rounds
* \param dtrain training data
*/
void XGBoosterUpdateOneIter_R(SEXP ext, SEXP iter, SEXP dtrain);
/*!
* \brief update the model, by directly specify gradient and second order gradient,
* this can be used to replace UpdateOneIter, to support customized loss function
* \param handle handle
* \param dtrain training data
* \param grad gradient statistics
* \param hess second order gradient statistics
*/
void XGBoosterBoostOneIter_R(SEXP handle, SEXP dtrain, SEXP grad, SEXP hess);
/*!
* \brief get evaluation statistics for xgboost
* \param handle handle
* \param iter current iteration rounds
* \param dmats list of handles to dmatrices
* \param evname name of evaluation
* \return the string containing evaluation stati
*/
SEXP XGBoosterEvalOneIter_R(SEXP handle, SEXP iter, SEXP dmats, SEXP evnames);
/*!
* \brief make prediction based on dmat
* \param handle handle
* \param dmat data matrix
* \param output_margin whether only output raw margin value
*/
SEXP XGBoosterPredict_R(SEXP handle, SEXP dmat, SEXP output_margin);
/*!
* \brief load model from existing file
* \param handle handle
* \param fname file name
*/
void XGBoosterLoadModel_R(SEXP handle, SEXP fname);
/*!
* \brief save model into existing file
* \param handle handle
* \param fname file name
*/
void XGBoosterSaveModel_R(SEXP handle, SEXP fname);
/*!
* \brief dump model into text file
* \param handle handle
* \param fname file name of model that can be dumped into
* \param fmap name to fmap can be empty string
*/
void XGBoosterDumpModel_R(SEXP handle, SEXP fname, SEXP fmap);
};
#endif // XGBOOST_WRAPPER_R_H_

View File

@ -16,7 +16,6 @@ extern "C" {
void* XGDMatrixCreateFromFile(const char *fname, int silent);
/*!
* \brief create a matrix content from csr format
* \param handle a instance of data matrix
* \param indptr pointer to row headers
* \param indices findex
* \param data fvalue
@ -31,7 +30,6 @@ extern "C" {
size_t nelem);
/*!
* \brief create matrix content from dense matrix
* \param handle a instance of data matrix
* \param data pointer to the data space
* \param nrow number of rows
* \param ncol number columns
@ -81,8 +79,8 @@ extern "C" {
/*!
* \brief get float info vector from matrix
* \param handle a instance of data matrix
* \param len used to set result length
* \param field field name
* \param out_len used to set result length
* \return pointer to the label
*/
const float* XGDMatrixGetFloatInfo(const void *handle, const char *field, size_t* out_len);
@ -114,7 +112,7 @@ extern "C" {
* \param handle handle
* \param iter current iteration rounds
* \param dtrain training data
*/
*/
void XGBoosterUpdateOneIter(void *handle, int iter, void *dtrain);
/*!
* \brief update the model, by directly specify gradient and second order gradient,
@ -127,7 +125,7 @@ extern "C" {
*/
void XGBoosterBoostOneIter(void *handle, void *dtrain,
float *grad, float *hess, size_t len);
/*!
/*!
* \brief get evaluation statistics for xgboost
* \param handle handle
* \param iter current iteration rounds
@ -135,7 +133,7 @@ extern "C" {
* \param evnames pointers to names of each data
* \param len length of dmats
* \return the string containing evaluation stati
*/
*/
const char *XGBoosterEvalOneIter(void *handle, int iter, void *dmats[],
const char *evnames[], size_t len);
/*!
@ -165,7 +163,7 @@ extern "C" {
* \param out_len length of output array
* \return char *data[], representing dump of each model
*/
const char** XGBoosterDumpModel(void *handle, const char *fmap,
const char **XGBoosterDumpModel(void *handle, const char *fmap,
size_t *out_len);
};
#endif // XGBOOST_WRAPPER_H_