add java wrapper

This commit is contained in:
yanqingmen 2015-06-09 23:14:50 -07:00
parent fcca359774
commit f91a098770
37 changed files with 3545 additions and 1 deletions

8
.gitignore vendored
View File

@ -58,3 +58,11 @@ R-package.Rproj
*.cache* *.cache*
R-package/inst R-package/inst
R-package/src R-package/src
#java
java/xgboost4j/target
java/xgboost4j/tmp
java/xgboost4j-demo/target
java/xgboost4j-demo/data/
java/xgboost4j-demo/tmp/
java/xgboost4j-demo/model/
nb-configuration*

View File

@ -3,6 +3,8 @@ export CXX = g++
export MPICXX = mpicxx export MPICXX = mpicxx
export LDFLAGS= -pthread -lm export LDFLAGS= -pthread -lm
export CFLAGS = -Wall -O3 -msse2 -Wno-unknown-pragmas -funroll-loops export CFLAGS = -Wall -O3 -msse2 -Wno-unknown-pragmas -funroll-loops
# java include path
export JAVAINCFLAGS = -I${JAVA_HOME}/include -I${JAVA_HOME}/include/linux -I./java
ifeq ($(OS), Windows_NT) ifeq ($(OS), Windows_NT)
export CXX = g++ -m64 export CXX = g++ -m64
@ -53,6 +55,9 @@ else
SLIB = wrapper/libxgboostwrapper.so SLIB = wrapper/libxgboostwrapper.so
endif endif
# java lib
JLIB = java/libxgboostjavawrapper.so
# specify tensor path # specify tensor path
BIN = xgboost BIN = xgboost
MOCKBIN = xgboost.mock MOCKBIN = xgboost.mock
@ -79,6 +84,9 @@ main.o: src/xgboost_main.cpp src/utils/*.h src/*.h src/learner/*.hpp src/learner
xgboost: updater.o gbm.o io.o main.o $(LIBRABIT) $(LIBDMLC) xgboost: updater.o gbm.o io.o main.o $(LIBRABIT) $(LIBDMLC)
wrapper/xgboost_wrapper.dll wrapper/libxgboostwrapper.so: wrapper/xgboost_wrapper.cpp src/utils/*.h src/*.h src/learner/*.hpp src/learner/*.h updater.o gbm.o io.o $(LIBRABIT) $(LIBDMLC) wrapper/xgboost_wrapper.dll wrapper/libxgboostwrapper.so: wrapper/xgboost_wrapper.cpp src/utils/*.h src/*.h src/learner/*.hpp src/learner/*.h updater.o gbm.o io.o $(LIBRABIT) $(LIBDMLC)
java: java/libxgboostjavawrapper.so
java/libxgboostjavawrapper.so: java/xgboost4j_wrapper.cpp wrapper/xgboost_wrapper.cpp src/utils/*.h src/*.h src/learner/*.hpp src/learner/*.h updater.o gbm.o io.o $(LIBRABIT) $(LIBDMLC)
# dependency on rabit # dependency on rabit
subtree/rabit/lib/librabit.a: subtree/rabit/src/engine.cc subtree/rabit/lib/librabit.a: subtree/rabit/src/engine.cc
+ cd subtree/rabit;make lib/librabit.a; cd ../.. + cd subtree/rabit;make lib/librabit.a; cd ../..
@ -98,6 +106,9 @@ $(MOCKBIN) :
$(SLIB) : $(SLIB) :
$(CXX) $(CFLAGS) -fPIC -shared -o $@ $(filter %.cpp %.o %.c %.a %.cc, $^) $(LDFLAGS) $(DLLFLAGS) $(CXX) $(CFLAGS) -fPIC -shared -o $@ $(filter %.cpp %.o %.c %.a %.cc, $^) $(LDFLAGS) $(DLLFLAGS)
$(JLIB) :
$(CXX) $(CFLAGS) -fPIC -shared -o $@ $(filter %.so %.cpp %.o %.c %.a %.cc, $^) $(LDFLAGS) $(JAVAINCFLAGS)
$(OBJ) : $(OBJ) :
$(CXX) -c $(CFLAGS) -o $@ $(firstword $(filter %.cpp %.c %.cc, $^) ) $(CXX) -c $(CFLAGS) -o $@ $(firstword $(filter %.cpp %.c %.cc, $^) )

28
java/README.md Normal file
View File

@ -0,0 +1,28 @@
# xgboost4j
this is a java wrapper for xgboost
the structure of this wrapper is almost the same as the official python wrapper.
core of this wrapper is two classes:
* DMatrix: for handling data
* Booster: for train and predict
## usage:
please refer to [xgboost4j.md](doc/xgboost4j.md) for more information.
besides, simple examples could be found in [xgboost4j-demo](xgboost4j-demo/README.md)
## build native library
for windows: open the xgboost.sln in the windows folder, where you will find the xgboostjavawrapper project; do the following steps to build the wrapper library:
* Select x64/win32 and Release in build
* (if you have set `JAVA_HOME` properly in your Windows environment variables, skip this step) right click on xgboostjavawrapper project -> choose "Properties" -> click on "C/C++" in the window -> change the "Additional Include Directories" to fit your jdk install path.
* rebuild all
* move the dll "xgboostjavawrapper.dll" to "xgboost4j/src/main/resources/lib/"(you may need to create this folder if necessary.)
for linux:
* make sure you have installed the JDK and that `JAVA_HOME` has been set properly
* run "create_wrap.sh"

15
java/create_wrap.sh Executable file
View File

@ -0,0 +1,15 @@
#!/bin/sh
# Build the xgboost JNI wrapper shared library and install it into the
# xgboost4j Maven resource tree so it gets bundled into the jar.
#
# Fix: abort on the first failing command — previously a failed `make java`
# was ignored and the script went on to `mv` a library that does not exist.
set -e

echo "build java wrapper"
cd ..
make java
cd java

echo "move native lib"
libPath="xgboost4j/src/main/resources/lib"
# -p: create missing parent directories too, and do not fail if it exists
mkdir -p "$libPath"
# remove any stale copy before installing the fresh build
rm -f "$libPath/libxgboostjavawrapper.so"
mv libxgboostjavawrapper.so "$libPath"/
echo "complete"

157
java/doc/xgboost4j.md Normal file
View File

@ -0,0 +1,157 @@
xgboost4j : java wrapper for xgboost
====
This page will introduce xgboost4j, the java wrapper for xgboost, including:
* [Building](#build-xgboost4j)
* [Data Interface](#data-interface)
* [Setting Parameters](#setting-parameters)
* [Train Model](#training-model)
* [Prediction](#prediction)
=
#### Build xgboost4j
* Build native library
first make sure you have installed the JDK and that `JAVA_HOME` has been set properly, then simply run `./create_wrap.sh`.
* Package xgboost4j
to package xgboost4j, you can run `mvn package` in xgboost4j folder or just use IDE(eclipse/netbeans) to open this maven project and build.
=
#### Data Interface
Like the xgboost python module, xgboost4j use ```DMatrix``` to handle data, libsvm txt format file, sparse matrix in CSR/CSC format, and dense matrix is supported.
* To import ```DMatrix``` :
```java
import org.dmlc.xgboost4j.DMatrix;
```
* To load libsvm text format file, the usage is like :
```java
DMatrix dmat = new DMatrix("train.svm.txt");
```
* To load sparse matrix in CSR/CSC format is a little complicated, the usage is like :
suppose a sparse matrix :
1 0 2 0
4 0 0 3
3 1 2 0
for CSR format
```java
long[] rowHeaders = new long[] {0,2,4,7};
float[] data = new float[] {1f,2f,4f,3f,3f,1f,2f};
int[] colIndex = new int[] {0,2,0,3,0,1,2};
DMatrix dmat = new DMatrix(rowHeaders, colIndex, data, DMatrix.SparseType.CSR);
```
for CSC format
```java
long[] colHeaders = new long[] {0,3,4,6,7};
float[] data = new float[] {1f,4f,3f,1f,2f,2f,3f};
int[] rowIndex = new int[] {0,1,2,2,0,2,1};
DMatrix dmat = new DMatrix(colHeaders, rowIndex, data, DMatrix.SparseType.CSC);
```
* To load 3*2 dense matrix, the usage is like :
suppose a matrix :
1 2
3 4
5 6
```java
float[] data = new float[] {1f,2f,3f,4f,5f,6f};
int nrow = 3;
int ncol = 2;
float missing = 0.0f;
DMatrix dmat = new DMatrix(data, nrow, ncol, missing);
```
* To set weight :
```java
float[] weights = new float[] {1f,2f,1f};
dmat.setWeight(weights);
```
#### Setting Parameters
* A util class ```Params``` in xgboost4j is used to handle parameters.
* To import ```Params``` :
```java
import org.dmlc.xgboost4j.util.Params;
```
* to set parameters :
```java
Params params = new Params() {
{
put("eta", "1.0");
put("max_depth", "2");
put("silent", "1");
put("objective", "binary:logistic");
put("eval_metric", "logloss");
}
};
```
* Multiple values with same param key is handled naturally in ```Params```, e.g. :
```java
Params params = new Params() {
{
put("eta", "1.0");
put("max_depth", "2");
put("silent", "1");
put("objective", "binary:logistic");
put("eval_metric", "logloss");
put("eval_metric", "error");
}
};
```
#### Training Model
With parameters and data, you are able to train a booster model.
* Import ```Trainer``` and ```Booster``` :
```java
import org.dmlc.xgboost4j.Booster;
import org.dmlc.xgboost4j.util.Trainer;
```
* Training
```java
DMatrix trainMat = new DMatrix("train.svm.txt");
DMatrix validMat = new DMatrix("valid.svm.txt");
DMatrix[] evalMats = new DMatrix[] {trainMat, validMat};
String[] evalNames = new String[] {"train", "valid"};
int round = 2;
Booster booster = Trainer.train(params, trainMat, round, evalMats, evalNames, null, null);
```
* Saving model
After training, you can save model and dump it out.
```java
booster.saveModel("model.bin");
```
* Dump Model and Feature Map
```java
booster.dumpModel("modelInfo.txt", false)
//dump with featureMap
booster.dumpModel("modelInfo.txt", "featureMap.txt", false)
```
* Load a model
```java
Params param = new Params() {
{
put("silent", "1");
put("nthread", "6");
}
};
Booster booster = new Booster(param, "model.bin");
```
#### Prediction
After training or loading a model, you can use it to predict on other data. The prediction result is a two-dimensional float array of shape (nsample, nclass); for leaf prediction it is (nsample, nclass*ntrees).
```java
DMatrix dtest = new DMatrix("test.svm.txt");
//predict
float[][] predicts = booster.predict(dtest);
//predict leaf
float[][] leafPredicts = booster.predict(dtest, 0, true);
```

View File

@ -0,0 +1,15 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

View File

@ -0,0 +1,10 @@
xgboost4j examples
====
* [Basic walkthrough of wrappers](src/main/java/org/dmlc/xgboost4j/demo/BasicWalkThrough.java)
* [Customize loss function, and evaluation metric](src/main/java/org/dmlc/xgboost4j/demo/CustomObjective.java)
* [Boosting from existing prediction](src/main/java/org/dmlc/xgboost4j/demo/BoostFromPrediction.java)
* [Predicting using first n trees](src/main/java/org/dmlc/xgboost4j/demo/PredictFirstNtree.java)
* [Generalized Linear Model](src/main/java/org/dmlc/xgboost4j/demo/GeneralizedLinearModel.java)
* [Cross validation](src/main/java/org/dmlc/xgboost4j/demo/CrossValidation.java)
* [Predicting leaf indices](src/main/java/org/dmlc/xgboost4j/demo/PredictLeafIndices.java)
* [External Memory](src/main/java/org/dmlc/xgboost4j/demo/ExternalMemory.java)

View File

@ -0,0 +1,36 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>org.dmlc</groupId>
<artifactId>xgboost4j-demo</artifactId>
<version>1.0</version>
<packaging>jar</packaging>
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<maven.compiler.source>1.7</maven.compiler.source>
<maven.compiler.target>1.7</maven.compiler.target>
</properties>
<dependencies>
<dependency>
<groupId>org.dmlc</groupId>
<artifactId>xgboost4j</artifactId>
<version>1.1</version>
</dependency>
<dependency>
<groupId>commons-io</groupId>
<artifactId>commons-io</artifactId>
<version>2.4</version>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId>
<version>3.4</version>
</dependency>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<version>4.11</version>
<scope>test</scope>
</dependency>
</dependencies>
</project>

View File

@ -0,0 +1,117 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j.demo;
import java.io.File;
import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.util.Arrays;
import org.dmlc.xgboost4j.Booster;
import org.dmlc.xgboost4j.DMatrix;
import org.dmlc.xgboost4j.demo.util.DataLoader;
import org.dmlc.xgboost4j.util.Params;
import org.dmlc.xgboost4j.util.Trainer;
/**
 * Basic end-to-end walkthrough of the xgboost4j wrapper: train, predict,
 * save, reload, and rebuild a DMatrix from raw CSR arrays.
 * @author hzx
 */
public class BasicWalkThrough {
    /**
     * Returns true iff both prediction matrices have the same number of rows
     * and every row is element-wise identical.
     * @param fPredicts first prediction matrix
     * @param sPredicts second prediction matrix
     * @return whether the two matrices are equal
     */
    public static boolean checkPredicts(float[][] fPredicts, float[][] sPredicts) {
        if (fPredicts.length != sPredicts.length) {
            return false;
        }
        int row = 0;
        while (row < fPredicts.length) {
            if (!Arrays.equals(fPredicts[row], sPredicts[row])) {
                return false;
            }
            row++;
        }
        return true;
    }

    public static void main(String[] args) throws UnsupportedEncodingException, IOException {
        // DMatrix accepts libsvm text files as well as binary buffers
        // previously written by xgboost4j
        DMatrix dtrain = new DMatrix("../../demo/data/agaricus.txt.train");
        DMatrix dtest = new DMatrix("../../demo/data/agaricus.txt.test");

        // booster parameters
        Params params = new Params() {
            {
                put("eta", "1.0");
                put("max_depth", "2");
                put("silent", "1");
                put("objective", "binary:logistic");
            }
        };

        // datasets evaluated every round, with their display names
        DMatrix[] watches = new DMatrix[] {dtrain, dtest};
        String[] watchNames = new String[] {"train", "test"};
        int numRound = 2;

        // fit the model and predict on the test set
        Booster booster = Trainer.train(params, dtrain, numRound, watches, watchNames, null, null);
        float[][] predicts = booster.predict(dtest);

        // make sure the output directory exists before saving anything
        File modelDir = new File("./model");
        if (!modelDir.exists()) {
            modelDir.mkdirs();
        }
        String modelPath = "./model/xgb.model";
        booster.saveModel(modelPath);
        // plain-text dump of the trees
        booster.dumpModel("./model/dump.raw.txt", false);
        // dump again, mapping feature indices to names via the feature map
        booster.dumpModel("./model/dump.nice.txt", "../../demo/data/featmap.txt", false);
        // persist the test matrix as a binary buffer
        dtest.saveBinary("./model/dtest.buffer");

        // reload model and data, then verify predictions are unchanged
        Booster booster2 = new Booster(params, "./model/xgb.model");
        DMatrix dtest2 = new DMatrix("./model/dtest.buffer");
        float[][] predicts2 = booster2.predict(dtest2);
        System.out.println(checkPredicts(predicts, predicts2));

        System.out.println("start build dmatrix from csr sparse data ...");
        // rebuild the training matrix from raw CSR arrays and train again
        DataLoader.CSRSparseData spData = DataLoader.loadSVMFile("../../demo/data/agaricus.txt.train");
        DMatrix dtrain2 = new DMatrix(spData.rowHeaders, spData.colIndex, spData.data, DMatrix.SparseType.CSR);
        dtrain2.setLabel(spData.labels);
        watches = new DMatrix[] {dtrain2, dtest};
        Booster booster3 = Trainer.train(params, dtrain2, numRound, watches, watchNames, null, null);
        float[][] predicts3 = booster3.predict(dtest2);
        System.out.println(checkPredicts(predicts, predicts3));
    }
}

View File

@ -0,0 +1,61 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j.demo;
import org.dmlc.xgboost4j.Booster;
import org.dmlc.xgboost4j.DMatrix;
import org.dmlc.xgboost4j.util.Params;
import org.dmlc.xgboost4j.util.Trainer;
/**
 * Example: continue boosting from an initial base prediction (margin).
 * @author hzx
 */
public class BoostFromPrediction {
    public static void main(String[] args) {
        System.out.println("start running example to start from a initial prediction");
        // text files and xgboost4j binary buffers are both accepted
        DMatrix dtrain = new DMatrix("../../demo/data/agaricus.txt.train");
        DMatrix dtest = new DMatrix("../../demo/data/agaricus.txt.test");

        // booster parameters
        Params params = new Params() {
            {
                put("eta", "1.0");
                put("max_depth", "2");
                put("silent", "1");
                put("objective", "binary:logistic");
            }
        };

        DMatrix[] watches = new DMatrix[] {dtrain, dtest};
        String[] watchNames = new String[] {"train", "test"};

        // train a single round first
        Booster booster = Trainer.train(params, dtrain, 1, watches, watchNames, null, null);

        // NOTE(review): the boolean arg presumably requests raw margin output
        // (it is fed straight into setBaseMargin below) — confirm against the
        // Booster.predict API
        float[][] trainMargin = booster.predict(dtrain, true);
        float[][] testMargin = booster.predict(dtest, true);
        dtrain.setBaseMargin(trainMargin);
        dtest.setBaseMargin(testMargin);

        System.out.println("result of running from initial prediction");
        // this round starts from the margins installed above
        Booster booster2 = Trainer.train(params, dtrain, 1, watches, watchNames, null, null);
    }
}

View File

@ -0,0 +1,53 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j.demo;
import java.io.IOException;
import org.dmlc.xgboost4j.DMatrix;
import org.dmlc.xgboost4j.util.Trainer;
import org.dmlc.xgboost4j.util.Params;
/**
 * Example: k-fold cross validation with xgboost4j.
 * @author hzx
 */
public class CrossValidation {
    public static void main(String[] args) throws IOException {
        // training data in libsvm text format
        DMatrix dtrain = new DMatrix("../../demo/data/agaricus.txt.train");

        // booster parameters
        Params params = new Params() {
            {
                put("eta", "1.0");
                put("max_depth", "3");
                put("silent", "1");
                put("nthread", "6");
                put("objective", "binary:logistic");
                put("gamma", "1.0");
                put("eval_metric", "error");
            }
        };

        // run 5-fold cross validation for 2 rounds
        int numRound = 2;
        int numFold = 5;
        // no additional metrics beyond eval_metric above
        String[] extraMetrics = null;
        // note: the method name "crossValiation" (sic) is fixed by the Trainer API
        String[] evalHist = Trainer.crossValiation(params, dtrain, numRound, numFold, extraMetrics, null, null);
    }
}

View File

@ -0,0 +1,154 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j.demo;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.dmlc.xgboost4j.Booster;
import org.dmlc.xgboost4j.IEvaluation;
import org.dmlc.xgboost4j.DMatrix;
import org.dmlc.xgboost4j.IObjective;
import org.dmlc.xgboost4j.util.Params;
import org.dmlc.xgboost4j.util.Trainer;
/**
 * Example of a user-defined objective and evaluation function.
 * NOTE: with a customized loss, the predictions handed to the objective and
 * the evaluation are raw margins (e.g. scores before the logistic transform),
 * while built-in evaluation metrics assume transformed predictions. Keep this
 * in mind: a custom objective usually needs a matching custom evaluation.
 * @author hzx
 */
public class CustomObjective {
    /**
     * Custom objective: logistic-regression log loss, supplying the
     * first- and second-order gradients xgboost needs.
     */
    public static class LogRegObj implements IObjective {
        /**
         * Plain sigmoid. Numerical stability is not a concern here; this is
         * demo code only.
         * @param input raw margin score
         * @return value mapped into (0, 1)
         */
        public float sigmoid(float input) {
            return (float) (1 / (1 + Math.exp(-input)));
        }

        /**
         * Applies the sigmoid to the first column of every prediction row.
         * @param predicts raw margin predictions, one row per instance
         * @return transformed predictions of shape (nrow, 1)
         */
        public float[][] transform(float[][] predicts) {
            int rows = predicts.length;
            float[][] probs = new float[rows][1];
            for (int r = 0; r < rows; r++) {
                probs[r][0] = sigmoid(predicts[r][0]);
            }
            return probs;
        }

        /**
         * Gradient and hessian of the logistic loss w.r.t. the margin.
         * @param predicts raw margin predictions
         * @param dtrain training matrix supplying the labels
         * @return two-element list: [gradient, hessian]
         */
        @Override
        public List<float[]> getGradient(float[][] predicts, DMatrix dtrain) {
            int rows = predicts.length;
            float[] labels = dtrain.getLabel();
            float[] grad = new float[rows];
            float[] hess = new float[rows];
            float[][] probs = transform(predicts);
            for (int r = 0; r < rows; r++) {
                float p = probs[r][0];
                grad[r] = p - labels[r];
                hess[r] = p * (1 - p);
            }
            List<float[]> out = new ArrayList<>();
            out.add(grad);
            out.add(hess);
            return out;
        }
    }

    /**
     * Custom evaluation: classification error computed on raw margins.
     * Because the custom objective leaves predictions untransformed, the
     * decision threshold here is 0 rather than 0.5.
     */
    public static class EvalError implements IEvaluation {
        // metric name reported alongside the value
        String evalMetric = "custom_error";

        public EvalError() {
        }

        @Override
        public String getMetric() {
            return evalMetric;
        }

        @Override
        public float eval(float[][] predicts, DMatrix dmat) {
            float[] labels = dmat.getLabel();
            float wrong = 0f;
            for (int r = 0; r < predicts.length; r++) {
                if (labels[r] == 0f && predicts[r][0] > 0) {
                    wrong++;
                } else if (labels[r] == 1f && predicts[r][0] <= 0) {
                    wrong++;
                }
            }
            return wrong / labels.length;
        }
    }

    public static void main(String[] args) {
        // load train/test data (svmlight format)
        DMatrix dtrain = new DMatrix("../../demo/data/agaricus.txt.train");
        DMatrix dtest = new DMatrix("../../demo/data/agaricus.txt.test");

        // booster parameters — no "objective" entry: it comes from LogRegObj
        Params params = new Params() {
            {
                put("eta", "1.0");
                put("max_depth", "2");
                put("silent", "1");
            }
        };

        int numRound = 2;
        DMatrix[] watches = new DMatrix[] {dtrain, dtest};
        String[] watchNames = new String[] {"train", "eval"};

        // user-defined objective and evaluation
        IObjective obj = new LogRegObj();
        IEvaluation eval = new EvalError();

        System.out.println("begin to train the booster model");
        Booster booster = Trainer.train(params, dtrain, numRound, watches, watchNames, obj, eval);
    }
}

View File

@ -0,0 +1,59 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j.demo;
import org.dmlc.xgboost4j.Booster;
import org.dmlc.xgboost4j.DMatrix;
import org.dmlc.xgboost4j.util.Params;
import org.dmlc.xgboost4j.util.Trainer;
/**
 * Example: external-memory training. Appending "#cachePrefix" to a libsvm
 * path makes xgboost stream the data through on-disk cache files.
 * @author hzx
 */
public class ExternalMemory {
    public static void main(String[] args) {
        // the "#dtrain.cache" suffix is the only difference from in-memory use;
        // several cache files with that prefix will be generated.
        // currently only conversion from libsvm files is supported.
        DMatrix dtrain = new DMatrix("../../demo/data/agaricus.txt.train#dtrain.cache");
        DMatrix dtest = new DMatrix("../../demo/data/agaricus.txt.test#dtest.cache");

        // booster parameters
        Params params = new Params() {
            {
                put("eta", "1.0");
                put("max_depth", "2");
                put("silent", "1");
                put("objective", "binary:logistic");
            }
        };
        // performance tip: set nthread to the number of real cores —
        // e.g. nthread=4 on a 4-core CPU that exposes 8 hardware threads
        //params.put("nthread", "num_real_cpu");

        DMatrix[] watches = new DMatrix[] {dtrain, dtest};
        String[] watchNames = new String[] {"train", "test"};
        int numRound = 2;
        Booster booster = Trainer.train(params, dtrain, numRound, watches, watchNames, null, null);
    }
}

View File

@ -0,0 +1,68 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j.demo;
import org.dmlc.xgboost4j.Booster;
import org.dmlc.xgboost4j.DMatrix;
import org.dmlc.xgboost4j.demo.util.CustomEval;
import org.dmlc.xgboost4j.util.Params;
import org.dmlc.xgboost4j.util.Trainer;
/**
 * Example: fit a generalized linear model by selecting the gblinear booster,
 * i.e. boosting linear models instead of trees.
 * @author hzx
 */
public class GeneralizedLinearModel {
    public static void main(String[] args) {
        // libsvm text files or xgboost4j binary buffers are both accepted
        DMatrix dtrain = new DMatrix("../../demo/data/agaricus.txt.train");
        DMatrix dtest = new DMatrix("../../demo/data/agaricus.txt.test");

        // booster=gblinear switches to a linear model;
        // alpha is the L1 regularizer, lambda the L2 regularizer,
        // and lambda_bias applies L2 to the bias term
        Params params = new Params() {
            {
                put("alpha", "0.0001");
                put("silent", "1");
                put("objective", "binary:logistic");
                put("booster", "gblinear");
            }
        };
        // eta (step size) is normally left at its default: xgboost uses a
        // parallel coordinate-descent algorithm (shotgun), and a smaller eta
        // such as 0.5 can stabilize convergence if parallelization hurts it
        //params.put("eta", "0.5");

        DMatrix[] watches = new DMatrix[] {dtrain, dtest};
        String[] watchNames = new String[] {"train", "test"};

        int numRound = 4;
        Booster booster = Trainer.train(params, dtrain, numRound, watches, watchNames, null, null);

        // report classification error on the test set
        float[][] predicts = booster.predict(dtest);
        CustomEval eval = new CustomEval();
        System.out.println("error=" + eval.eval(predicts, dtest));
    }
}

View File

@ -0,0 +1,63 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j.demo;
import org.dmlc.xgboost4j.Booster;
import org.dmlc.xgboost4j.DMatrix;
import org.dmlc.xgboost4j.util.Params;
import org.dmlc.xgboost4j.util.Trainer;
import org.dmlc.xgboost4j.demo.util.CustomEval;
/**
 * Example: predict with only the first N trees of a trained model.
 * @author hzx
 */
public class PredictFirstNtree {
    public static void main(String[] args) {
        // libsvm text files or xgboost4j binary buffers are both accepted
        DMatrix dtrain = new DMatrix("../../demo/data/agaricus.txt.train");
        DMatrix dtest = new DMatrix("../../demo/data/agaricus.txt.test");

        // booster parameters
        Params params = new Params() {
            {
                put("eta", "1.0");
                put("max_depth", "2");
                put("silent", "1");
                put("objective", "binary:logistic");
            }
        };

        DMatrix[] watches = new DMatrix[] {dtrain, dtest};
        String[] watchNames = new String[] {"train", "test"};
        int numRound = 3;
        Booster booster = Trainer.train(params, dtrain, numRound, watches, watchNames, null, null);

        // restrict prediction to the first tree only
        float[][] predicts1 = booster.predict(dtest, false, 1);
        // by default every tree contributes to the prediction
        float[][] predicts2 = booster.predict(dtest);

        // compare error rates of the two prediction modes
        CustomEval eval = new CustomEval();
        System.out.println("error of predicts1: " + eval.eval(predicts1, dtest));
        System.out.println("error of predicts2: " + eval.eval(predicts2, dtest));
    }
}

View File

@ -0,0 +1,64 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j.demo;
import java.util.Arrays;
import org.dmlc.xgboost4j.Booster;
import org.dmlc.xgboost4j.DMatrix;
import org.dmlc.xgboost4j.util.Params;
import org.dmlc.xgboost4j.util.Trainer;
/**
 * Example: output the leaf index each instance falls into, tree by tree.
 * @author hzx
 */
public class PredictLeafIndices {
    public static void main(String[] args) {
        // libsvm text files or xgboost4j binary buffers are both accepted
        DMatrix dtrain = new DMatrix("../../demo/data/agaricus.txt.train");
        DMatrix dtest = new DMatrix("../../demo/data/agaricus.txt.test");

        // booster parameters
        Params params = new Params() {
            {
                put("eta", "1.0");
                put("max_depth", "2");
                put("silent", "1");
                put("objective", "binary:logistic");
            }
        };

        DMatrix[] watches = new DMatrix[] {dtrain, dtest};
        String[] watchNames = new String[] {"train", "test"};
        int numRound = 3;
        Booster booster = Trainer.train(params, dtrain, numRound, watches, watchNames, null, null);

        // leaf indices using only the first 2 trees
        float[][] leafIndex = booster.predict(dtest, 2, true);
        for (float[] leafs : leafIndex) {
            System.out.println(Arrays.toString(leafs));
        }
        // ntree limit 0 means: use every tree
        leafIndex = booster.predict(dtest, 0, true);
        for (float[] leafs : leafIndex) {
            System.out.println(Arrays.toString(leafs));
        }
    }
}

View File

@ -0,0 +1,50 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j.demo.util;
import org.dmlc.xgboost4j.DMatrix;
import org.dmlc.xgboost4j.IEvaluation;
/**
 * Evaluation helper shared by the examples: binary classification error
 * rate with a 0.5 probability threshold.
 * @author hzx
 */
public class CustomEval implements IEvaluation {
    // metric name reported alongside the value
    String evalMetric = "custom_error";

    @Override
    public String getMetric() {
        return evalMetric;
    }

    @Override
    public float eval(float[][] predicts, DMatrix dmat) {
        float[] labels = dmat.getLabel();
        float wrong = 0f;
        for (int row = 0; row < predicts.length; row++) {
            if (labels[row] == 0f && predicts[row][0] > 0.5) {
                wrong++;
            } else if (labels[row] == 1f && predicts[row][0] <= 0.5) {
                wrong++;
            }
        }
        return wrong / labels.length;
    }
}

View File

@ -0,0 +1,129 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j.demo.util;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.UnsupportedEncodingException;
import java.util.ArrayList;
import java.util.List;
import org.apache.commons.lang3.ArrayUtils;
/**
* util class for loading data
* @author hzx
*/
public class DataLoader {
    /**
     * Dense matrix holder: labels plus a row-major value array of
     * nrow * ncol entries.
     */
    public static class DenseData {
        public float[] labels;
        public float[] data;   // row-major, nrow * ncol values
        public int nrow;
        public int ncol;
    }

    /**
     * CSR sparse matrix holder: labels, non-zero values, row pointers and
     * column indices.
     */
    public static class CSRSparseData {
        public float[] labels;
        public float[] data;
        public long[] rowHeaders;
        public int[] colIndex;
    }

    /**
     * Load a CSV file where the LAST column of each row is the label and all
     * preceding columns are feature values.
     *
     * @param filePath path of the CSV file
     * @return dense data holder
     * @throws FileNotFoundException if the file does not exist
     * @throws UnsupportedEncodingException if UTF-8 is unavailable
     * @throws IOException on read failure
     */
    public static DenseData loadCSVFile(String filePath) throws FileNotFoundException, UnsupportedEncodingException, IOException {
        DenseData denseData = new DenseData();
        denseData.nrow = 0;
        denseData.ncol = -1;   // derived from the first non-empty row

        List<Float> tlabels = new ArrayList<>();
        List<Float> tdata = new ArrayList<>();

        File f = new File(filePath);
        FileInputStream in = new FileInputStream(f);
        BufferedReader reader = new BufferedReader(new InputStreamReader(in, "UTF-8"));
        try {
            String line;
            while ((line = reader.readLine()) != null) {
                line = line.trim();
                if (line.isEmpty()) {
                    // skip blank lines; the original length==0 check never fired
                    // and Float.valueOf("") would have thrown
                    continue;
                }
                String[] items = line.split(",");
                denseData.nrow++;
                if (denseData.ncol == -1) {
                    denseData.ncol = items.length - 1;
                }
                tlabels.add(Float.valueOf(items[items.length - 1]));
                for (int i = 0; i < items.length - 1; i++) {
                    tdata.add(Float.valueOf(items[i]));
                }
            }
        } finally {
            // close even if parsing throws
            reader.close();
            in.close();
        }

        // BUG FIX: List.toArray() returns Object[]; the original cast to
        // Float[] always threw ClassCastException at runtime. Use the typed
        // toArray(T[]) overload, as loadSVMFile already does.
        denseData.labels = ArrayUtils.toPrimitive(tlabels.toArray(new Float[tlabels.size()]));
        denseData.data = ArrayUtils.toPrimitive(tdata.toArray(new Float[tdata.size()]));
        return denseData;
    }

    /**
     * Load a file in svmlight format: "label index:value index:value ...".
     *
     * @param filePath path of the svmlight file
     * @return CSR sparse data holder
     * @throws FileNotFoundException if the file does not exist
     * @throws UnsupportedEncodingException if UTF-8 is unavailable
     * @throws IOException on read failure
     */
    public static CSRSparseData loadSVMFile(String filePath) throws FileNotFoundException, UnsupportedEncodingException, IOException {
        CSRSparseData spData = new CSRSparseData();

        List<Float> tlabels = new ArrayList<>();
        List<Float> tdata = new ArrayList<>();
        List<Long> theaders = new ArrayList<>();
        List<Integer> tindex = new ArrayList<>();

        File f = new File(filePath);
        FileInputStream in = new FileInputStream(f);
        BufferedReader reader = new BufferedReader(new InputStreamReader(in, "UTF-8"));
        try {
            String line;
            long rowheader = 0;
            theaders.add(rowheader);   // CSR row pointer array starts at 0
            while ((line = reader.readLine()) != null) {
                line = line.trim();
                if (line.isEmpty()) {
                    // skip blank lines without emitting a row pointer
                    continue;
                }
                String[] items = line.split(" ");
                rowheader += items.length - 1;   // number of non-zeros in this row
                theaders.add(rowheader);
                tlabels.add(Float.valueOf(items[0]));
                for (int i = 1; i < items.length; i++) {
                    String[] tup = items[i].split(":");
                    assert tup.length == 2;
                    tdata.add(Float.valueOf(tup[1]));
                    tindex.add(Integer.valueOf(tup[0]));
                }
            }
        } finally {
            reader.close();
            in.close();
        }

        spData.labels = ArrayUtils.toPrimitive(tlabels.toArray(new Float[tlabels.size()]));
        spData.data = ArrayUtils.toPrimitive(tdata.toArray(new Float[tdata.size()]));
        spData.colIndex = ArrayUtils.toPrimitive(tindex.toArray(new Integer[tindex.size()]));
        spData.rowHeaders = ArrayUtils.toPrimitive(theaders.toArray(new Long[theaders.size()]));
        return spData;
    }
}

15
java/xgboost4j/LICENSE Normal file
View File

@ -0,0 +1,15 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

23
java/xgboost4j/README.md Normal file
View File

@ -0,0 +1,23 @@
# xgboost4j
This is a Java wrapper for xgboost (https://github.com/dmlc/xgboost).
The structure of this wrapper closely mirrors the official Python wrapper.
The core of this wrapper consists of two classes:
* DMatrix: for handling data
* Booster: for training and prediction
## usage:
simple examples could be found in test package:
* Simple Train Example: org.dmlc.xgboost4j.TrainExample.java
* Simple Predict Example: org.dmlc.xgboost4j.PredictExample.java
* Cross Validation Example: org.dmlc.xgboost4j.example.CVExample.java
## native library:
Only 64-bit Linux/Windows is supported at present. If you want to build the native wrapper library yourself, please refer to
https://github.com/yanqingmen/xgboost-java, then put your native library into the "./src/main/resources/lib" folder, replacing the originals (either "libxgboostjavawrapper.so" for Linux or "xgboostjavawrapper.dll" for Windows).

35
java/xgboost4j/pom.xml Normal file
View File

@ -0,0 +1,35 @@
<?xml version="1.0" encoding="UTF-8"?>
<!-- Maven build for the xgboost4j Java wrapper (Java 7, packaged as a jar). -->
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>org.dmlc</groupId>
<artifactId>xgboost4j</artifactId>
<version>1.1</version>
<packaging>jar</packaging>
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<maven.compiler.source>1.7</maven.compiler.source>
<maven.compiler.target>1.7</maven.compiler.target>
</properties>
<reporting>
<plugins>
<!-- javadoc report generation -->
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-javadoc-plugin</artifactId>
<version>2.10.3</version>
</plugin>
</plugins>
</reporting>
<dependencies>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<version>4.11</version>
<scope>test</scope>
</dependency>
<!-- logging facade used by Booster/DMatrix/Initializer -->
<dependency>
<groupId>commons-logging</groupId>
<artifactId>commons-logging</artifactId>
<version>1.2</version>
</dependency>
</dependencies>
</project>

View File

@ -0,0 +1,438 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.io.UnsupportedEncodingException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.dmlc.xgboost4j.util.Initializer;
import org.dmlc.xgboost4j.util.Params;
import org.dmlc.xgboost4j.util.TransferUtil;
import org.dmlc.xgboost4j.wrapper.XgboostJNI;
/**
* Booster for xgboost, similar to the python wrapper xgboost.py
* but custom obj function and eval function not supported at present.
* @author hzx
*/
public final class Booster {
    private static final Log logger = LogFactory.getLog(Booster.class);

    // handle to the underlying native booster; 0 once delete() has run
    long handle = 0;

    //load native library
    static {
        try {
            Initializer.InitXgboost();
        } catch (IOException ex) {
            logger.error("load native library failed.");
            logger.error(ex);
        }
    }

    /**
     * init Booster from dMatrixs
     * @param params parameters
     * @param dMatrixs DMatrix array (data sets cached by the booster)
     */
    public Booster(Params params, DMatrix[] dMatrixs) {
        init(dMatrixs);
        setParam("seed", "0");
        setParams(params);
    }

    /**
     * load model from modelPath
     * @param params parameters
     * @param modelPath booster modelPath (model generated by booster.saveModel)
     */
    public Booster(Params params, String modelPath) {
        handle = XgboostJNI.XGBoosterCreate(new long[] {});
        loadModel(modelPath);
        setParam("seed", "0");
        setParams(params);
    }

    // create the native booster over the given (possibly null) cached matrices
    private void init(DMatrix[] dMatrixs) {
        long[] handles = null;
        if (dMatrixs != null) {
            handles = TransferUtil.dMatrixs2handles(dMatrixs);
        }
        handle = XgboostJNI.XGBoosterCreate(handles);
    }

    /**
     * set parameter
     * @param key param name
     * @param value param value
     */
    public final void setParam(String key, String value) {
        XgboostJNI.XGBoosterSetParam(handle, key, value);
    }

    /**
     * set parameters
     * @param params parameter key-value pairs
     */
    public void setParams(Params params) {
        if (params != null) {
            for (Map.Entry<String, String> entry : params) {
                setParam(entry.getKey(), entry.getValue());
            }
        }
    }

    /**
     * Update (one iteration)
     * @param dtrain training data
     * @param iter current iteration number
     */
    public void update(DMatrix dtrain, int iter) {
        XgboostJNI.XGBoosterUpdateOneIter(handle, iter, dtrain.getHandle());
    }

    /**
     * update with customized obj func
     * @param dtrain training data
     * @param iter current iteration number (kept for signature symmetry with
     *             {@link #update(DMatrix, int)}; the boost step does not use it)
     * @param obj customized objective class
     */
    public void update(DMatrix dtrain, int iter, IObjective obj) {
        // objectives are defined on the raw, untransformed margin
        float[][] predicts = predict(dtrain, true);
        List<float[]> gradients = obj.getGradient(predicts, dtrain);
        boost(dtrain, gradients.get(0), gradients.get(1));
    }

    /**
     * update with given grad and hess
     * @param dtrain training data
     * @param grad first order of gradient
     * @param hess second order of gradient
     */
    public void boost(DMatrix dtrain, float[] grad, float[] hess) {
        if (grad.length != hess.length) {
            throw new AssertionError(String.format("grad/hess length mismatch %s / %s", grad.length, hess.length));
        }
        XgboostJNI.XGBoosterBoostOneIter(handle, dtrain.getHandle(), grad, hess);
    }

    /**
     * evaluate with given dmatrixs.
     * @param evalMatrixs dmatrixs for evaluation
     * @param evalNames name for eval dmatrixs, used for check results
     * @param iter current eval iteration
     * @return eval information
     */
    public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, int iter) {
        long[] handles = TransferUtil.dMatrixs2handles(evalMatrixs);
        return XgboostJNI.XGBoosterEvalOneIter(handle, iter, handles, evalNames);
    }

    /**
     * evaluate with given customized Evaluation class
     * @param evalMatrixs dmatrixs for evaluation
     * @param evalNames names matching evalMatrixs, used in the report string
     * @param iter current eval iteration (not used by custom metrics)
     * @param eval customized evaluation implementation
     * @return eval information
     */
    public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, int iter, IEvaluation eval) {
        String evalInfo = "";
        for (int i = 0; i < evalNames.length; i++) {
            String evalName = evalNames[i];
            DMatrix evalMat = evalMatrixs[i];
            float evalResult = eval.eval(predict(evalMat), evalMat);
            String evalMetric = eval.getMetric();
            evalInfo += String.format("\t%s-%s:%f", evalName, evalMetric, evalResult);
        }
        return evalInfo;
    }

    /**
     * evaluate with given dmatrix handles
     * @param dHandles evaluation data handles
     * @param evalNames name for eval dmatrixs, used for check results
     * @param iter current eval iteration
     * @return eval information
     */
    public String evalSet(long[] dHandles, String[] evalNames, int iter) {
        return XgboostJNI.XGBoosterEvalOneIter(handle, iter, dHandles, evalNames);
    }

    /**
     * evaluate with a single dmatrix, convenience form of evalSet
     * @param evalMat data to evaluate
     * @param evalName name used in the report string
     * @param iter current eval iteration
     * @return eval information
     */
    public String eval(DMatrix evalMat, String evalName, int iter) {
        DMatrix[] evalMats = new DMatrix[] {evalMat};
        String[] evalNames = new String[] {evalName};
        return evalSet(evalMats, evalNames, iter);
    }

    /**
     * base function for Predict
     * @param data input data
     * @param outPutMargin output the raw untransformed margin value
     * @param treeLimit limit number of trees, 0 means use all
     * @param predLeaf output predicted leaf indices instead of scores
     * @return predict results reshaped to (nrow, ncol)
     */
    private synchronized float[][] pred(DMatrix data, boolean outPutMargin, long treeLimit, boolean predLeaf) {
        // BUG FIX: build the native option mask with bitwise OR; the original
        // plain assignment let predLeaf silently discard outPutMargin.
        int optionMask = 0;
        if (outPutMargin) {
            optionMask |= 1;
        }
        if (predLeaf) {
            optionMask |= 2;
        }
        float[] rawPredicts = XgboostJNI.XGBoosterPredict(handle, data.getHandle(), optionMask, treeLimit);
        // the native side returns a flat array; reshape it row-major
        int row = (int) data.rowNum();
        int col = rawPredicts.length / row;
        float[][] predicts = new float[row][col];
        for (int i = 0; i < rawPredicts.length; i++) {
            predicts[i / col][i % col] = rawPredicts[i];
        }
        return predicts;
    }

    /**
     * Predict with data
     * @param data dmatrix storing the input
     * @return predict result
     */
    public float[][] predict(DMatrix data) {
        return pred(data, false, 0, false);
    }

    /**
     * Predict with data
     * @param data dmatrix storing the input
     * @param outPutMargin Whether to output the raw untransformed margin value.
     * @return predict result
     */
    public float[][] predict(DMatrix data, boolean outPutMargin) {
        return pred(data, outPutMargin, 0, false);
    }

    /**
     * Predict with data
     * @param data dmatrix storing the input
     * @param outPutMargin Whether to output the raw untransformed margin value.
     * @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees).
     * @return predict result
     */
    public float[][] predict(DMatrix data, boolean outPutMargin, long treeLimit) {
        return pred(data, outPutMargin, treeLimit, false);
    }

    /**
     * Predict with data
     * @param data dmatrix storing the input
     * @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees).
     * @param predLeaf When this option is on, the output will be a matrix of (nsample, ntrees), nsample = data.numRow
       with each record indicating the predicted leaf index of each sample in each tree.
       Note that the leaf index of a tree is unique per tree, so you may find leaf 1
       in both tree 1 and tree 0.
     * @return predict result
     */
    public float[][] predict(DMatrix data, long treeLimit, boolean predLeaf) {
        return pred(data, false, treeLimit, predLeaf);
    }

    /**
     * save model to modelPath
     * @param modelPath path to write the model file to
     */
    public void saveModel(String modelPath) {
        XgboostJNI.XGBoosterSaveModel(handle, modelPath);
    }

    private void loadModel(String modelPath) {
        XgboostJNI.XGBoosterLoadModel(handle, modelPath);
    }

    /**
     * get the dump of the model as a string array
     * @param withStats Controls whether the split statistics are output.
     * @return dumped model information, one entry per tree
     */
    public String[] getDumpInfo(boolean withStats) {
        // empty featureMap means "no feature name mapping"
        return getDumpInfo("", withStats);
    }

    /**
     * get the dump of the model as a string array
     * @param featureMap featureMap file
     * @param withStats Controls whether the split statistics are output.
     * @return dumped model information, one entry per tree
     */
    public String[] getDumpInfo(String featureMap, boolean withStats) {
        int statsFlag = withStats ? 1 : 0;
        return XgboostJNI.XGBoosterDumpModel(handle, featureMap, statsFlag);
    }

    /**
     * Dump model into a text file.
     * @param modelPath file to save dumped model info
     * @param withStats Controls whether the split statistics are output.
     * @throws FileNotFoundException if modelPath cannot be created
     * @throws UnsupportedEncodingException if UTF-8 is unavailable
     * @throws IOException on write failure
     */
    public void dumpModel(String modelPath, boolean withStats) throws FileNotFoundException, UnsupportedEncodingException, IOException {
        writeModelInfos(getDumpInfo(withStats), modelPath);
    }

    /**
     * Dump model into a text file.
     * @param modelPath file to save dumped model info
     * @param featureMap featureMap file
     * @param withStats Controls whether the split statistics are output.
     * @throws FileNotFoundException if modelPath cannot be created
     * @throws UnsupportedEncodingException if UTF-8 is unavailable
     * @throws IOException on write failure
     */
    public void dumpModel(String modelPath, String featureMap, boolean withStats) throws FileNotFoundException, UnsupportedEncodingException, IOException {
        writeModelInfos(getDumpInfo(featureMap, withStats), modelPath);
    }

    // shared writer for both dumpModel overloads; closes streams even on error
    private static void writeModelInfos(String[] modelInfos, String modelPath) throws IOException {
        File tf = new File(modelPath);
        FileOutputStream out = new FileOutputStream(tf);
        BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(out, "UTF-8"));
        try {
            for (int i = 0; i < modelInfos.length; i++) {
                writer.write("booster [" + i + "]:\n");
                writer.write(modelInfos[i]);
            }
        } finally {
            writer.close();
            out.close();
        }
    }

    /**
     * get importance of each feature
     * @return featureScore map, key: feature index, value: number of splits on that feature
     */
    public Map<String, Integer> getFeatureScore() {
        return countFeatureSplits(getDumpInfo(false));
    }

    /**
     * get importance of each feature
     * @param featureMap featureMap file used to name features in the dump
     * @return featureScore map, key: feature name, value: number of splits on that feature
     */
    public Map<String, Integer> getFeatureScore(String featureMap) {
        return countFeatureSplits(getDumpInfo(featureMap, false));
    }

    // count how often each feature appears as a split in the dumped trees;
    // shared by both getFeatureScore overloads (previously duplicated code)
    private static Map<String, Integer> countFeatureSplits(String[] modelInfos) {
        Map<String, Integer> featureScore = new HashMap<>();
        for (String tree : modelInfos) {
            for (String node : tree.split("\n")) {
                String[] array = node.split("\\[");
                if (array.length == 1) {
                    continue;   // leaf node: no split condition
                }
                // split condition looks like "[f12<0.5]"; extract the feature id
                String fid = array[1].split("\\]")[0];
                fid = fid.split("<")[0];
                Integer count = featureScore.get(fid);
                featureScore.put(fid, count == null ? 1 : count + 1);
            }
        }
        return featureScore;
    }

    @Override
    protected void finalize() {
        delete();
    }

    /** release the native booster handle; safe to call more than once */
    public synchronized void delete() {
        if (handle != 0l) {
            XgboostJNI.XGBoosterFree(handle);
            handle = 0;
        }
    }
}

View File

@ -0,0 +1,217 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j;
import java.io.IOException;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.dmlc.xgboost4j.util.Initializer;
import org.dmlc.xgboost4j.util.TransferUtil;
import org.dmlc.xgboost4j.wrapper.XgboostJNI;
/**
* DMatrix for xgboost, similar to the python wrapper xgboost.py
* @author hzx
*/
public class DMatrix {
    private static final Log logger = LogFactory.getLog(DMatrix.class);

    // handle to the native DMatrix; 0 once delete() has run
    long handle = 0;

    //load native library
    static {
        try {
            Initializer.InitXgboost();
        } catch (IOException ex) {
            logger.error("load native library failed.");
            logger.error(ex);
        }
    }

    /**
     * sparse matrix type (CSR or CSC)
     */
    public static enum SparseType {
        CSR,
        CSC;
    }

    /**
     * init DMatrix from file (svmlight format)
     * @param dataPath path of the data file
     */
    public DMatrix(String dataPath) {
        // second argument is the native loader's silent flag (1 = quiet)
        handle = XgboostJNI.XGDMatrixCreateFromFile(dataPath, 1);
    }

    /**
     * create DMatrix from sparse matrix
     * @param headers index to headers (rowHeaders for CSR or colHeaders for CSC)
     * @param indices Indices (colIndexs for CSR or rowIndexs for CSC)
     * @param data non zero values (sequence by row for CSR or by col for CSC)
     * @param st sparse matrix type (CSR or CSC)
     */
    public DMatrix(long[] headers, int[] indices, float[] data, SparseType st) {
        if (st == SparseType.CSR) {
            handle = XgboostJNI.XGDMatrixCreateFromCSR(headers, indices, data);
        } else if (st == SparseType.CSC) {
            handle = XgboostJNI.XGDMatrixCreateFromCSC(headers, indices, data);
        } else {
            // BUG FIX: was "throw new UnknownError(...)". UnknownError signals a
            // JVM fault; an invalid argument belongs to IllegalArgumentException.
            throw new IllegalArgumentException("unknown sparse type: " + st);
        }
    }

    /**
     * create DMatrix from dense matrix
     * @param data data values (row-major, nrow * ncol entries)
     * @param nrow number of rows
     * @param ncol number of columns
     */
    public DMatrix(float[] data, int nrow, int ncol) {
        // last argument: value treated as missing (0.0f)
        handle = XgboostJNI.XGDMatrixCreateFromMat(data, nrow, ncol, 0.0f);
    }

    /**
     * wrap an existing native handle; used by slice()
     * @param handle native DMatrix handle
     */
    private DMatrix(long handle) {
        this.handle = handle;
    }

    /**
     * set label of dmatrix
     * @param labels one value per row
     */
    public void setLabel(float[] labels) {
        XgboostJNI.XGDMatrixSetFloatInfo(handle, "label", labels);
    }

    /**
     * set weight of each instance
     * @param weights one value per row
     */
    public void setWeight(float[] weights) {
        XgboostJNI.XGDMatrixSetFloatInfo(handle, "weight", weights);
    }

    /**
     * if specified, xgboost will start from this init margin;
     * can be used to specify initial prediction to boost from
     * @param baseMargin initial margin values
     */
    public void setBaseMargin(float[] baseMargin) {
        XgboostJNI.XGDMatrixSetFloatInfo(handle, "base_margin", baseMargin);
    }

    /**
     * if specified, xgboost will start from this init margin;
     * 2-D form is flattened row-major before being handed to the native side
     * @param baseMargin initial margin values
     */
    public void setBaseMargin(float[][] baseMargin) {
        float[] flattenMargin = TransferUtil.flatten(baseMargin);
        setBaseMargin(flattenMargin);
    }

    /**
     * Set group sizes of DMatrix (used for ranking)
     * @param group group sizes
     */
    public void setGroup(int[] group) {
        XgboostJNI.XGDMatrixSetGroup(handle, group);
    }

    private float[] getFloatInfo(String field) {
        return XgboostJNI.XGDMatrixGetFloatInfo(handle, field);
    }

    // currently unused internally; kept for future uint-typed fields
    private int[] getIntInfo(String field) {
        return XgboostJNI.XGDMatrixGetUIntInfo(handle, field);
    }

    /**
     * get label values
     * @return label
     */
    public float[] getLabel() {
        return getFloatInfo("label");
    }

    /**
     * get weight of the DMatrix
     * @return weights
     */
    public float[] getWeight() {
        return getFloatInfo("weight");
    }

    /**
     * get base margin of the DMatrix
     * @return base margin
     */
    public float[] getBaseMargin() {
        return getFloatInfo("base_margin");
    }

    /**
     * Slice the DMatrix and return a new DMatrix that only contains `rowIndex`.
     * @param rowIndex rows to keep
     * @return sliced new DMatrix
     */
    public DMatrix slice(int[] rowIndex) {
        long sHandle = XgboostJNI.XGDMatrixSliceDMatrix(handle, rowIndex);
        return new DMatrix(sHandle);
    }

    /**
     * get the row number of DMatrix
     * @return number of rows
     */
    public long rowNum() {
        return XgboostJNI.XGDMatrixNumRow(handle);
    }

    /**
     * save DMatrix to filePath in xgboost's binary format
     * @param filePath destination path
     */
    public void saveBinary(String filePath) {
        // second argument is the native silent flag (1 = quiet)
        XgboostJNI.XGDMatrixSaveBinary(handle, filePath, 1);
    }

    /** @return the native DMatrix handle */
    public long getHandle() {
        return handle;
    }

    @Override
    protected void finalize() {
        delete();
    }

    /** release the native handle; safe to call more than once */
    public synchronized void delete() {
        if (handle != 0) {
            XgboostJNI.XGDMatrixFree(handle);
            handle = 0;
        }
    }
}

View File

@ -0,0 +1,36 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j;
/**
* interface for customized evaluation
* @author hzx
*/
public interface IEvaluation {
    /**
     * Name of the evaluation metric, reported next to its value.
     * @return the metric name
     */
    String getMetric();

    /**
     * Compute the metric for the given predictions against the data set.
     * @param predicts predictions, one row per instance
     * @param dmat the evaluated data set (source of the true labels)
     * @return the metric value
     */
    float eval(float[][] predicts, DMatrix dmat);
}

View File

@ -0,0 +1,32 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j;
import java.util.List;
/**
* interface for customize Object function
* @author hzx
*/
public interface IObjective {
    /**
     * User-defined objective: compute first- and second-order gradients
     * from untransformed margin predictions.
     * @param predicts untransformed margin predictions, one row per instance
     * @param dtrain training data
     * @return a two-element list: element 0 is the first-order gradient array,
     *         element 1 is the second-order gradient array
     */
    List<float[]> getGradient(float[][] predicts, DMatrix dtrain);
}

View File

@ -0,0 +1,85 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j.util;
import org.dmlc.xgboost4j.IEvaluation;
import org.dmlc.xgboost4j.Booster;
import org.dmlc.xgboost4j.DMatrix;
import org.dmlc.xgboost4j.IObjective;
/**
* cross validation package for xgb
* @author hzx
*/
public class CVPack {
    DMatrix dtrain;
    DMatrix dtest;
    DMatrix[] dmats;
    long[] dataArray;
    String[] names;
    Booster booster;

    /**
     * Build one fold of a cross-validation run: a booster trained on
     * dtrain and evaluated on both dtrain and dtest.
     * @param dtrain train data
     * @param dtest test data
     * @param params booster parameters
     */
    public CVPack(DMatrix dtrain, DMatrix dtest, Params params) {
        this.dtrain = dtrain;
        this.dtest = dtest;
        dmats = new DMatrix[] {dtrain, dtest};
        names = new String[] {"train", "test"};
        dataArray = TransferUtil.dMatrixs2handles(dmats);
        booster = new Booster(params, dmats);
    }

    /**
     * Run one boosting iteration on the training fold.
     * @param iter iteration number
     */
    public void update(int iter) {
        booster.update(dtrain, iter);
    }

    /**
     * Run one boosting iteration with a customized objective.
     * @param iter iteration number
     * @param obj customized objective
     */
    public void update(int iter, IObjective obj) {
        booster.update(dtrain, iter, obj);
    }

    /**
     * Evaluate the current model on both folds.
     * @param iter iteration number
     * @return eval information string
     */
    public String eval(int iter) {
        return booster.evalSet(dataArray, names, iter);
    }

    /**
     * Evaluate the current model on both folds with a customized metric.
     * @param iter iteration number
     * @param eval customized evaluation
     * @return eval information string
     */
    public String eval(int iter, IEvaluation eval) {
        return booster.evalSet(dmats, names, iter, eval);
    }
}

View File

@ -0,0 +1,92 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j.util;
import java.io.IOException;
import java.lang.reflect.Field;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
/**
* class to load native library
* @author hzx
*/
public class Initializer {
    private static final Log logger = LogFactory.getLog(Initializer.class);

    // guards against loading the native libraries more than once per JVM
    static boolean initialized = false;
    // directory searched before falling back to the copy bundled in the jar
    public static final String nativePath = "./lib";
    // resource path of the bundled native library inside the jar
    public static final String nativeResourcePath = "/lib/";
    public static final String[] libNames = new String[] {"xgboostjavawrapper"};

    /**
     * Load all required native libraries exactly once.
     * @throws IOException if a library can be found neither on the library
     *         path nor inside the jar
     */
    public static synchronized void InitXgboost() throws IOException {
        if (!initialized) {
            for (String libName : libNames) {
                smartLoad(libName);
            }
            initialized = true;
        }
    }

    /**
     * Load a native library: first try java.library.path (extended with
     * {@link #nativePath}), then fall back to the copy bundled in the jar.
     * @param libName platform-independent library name
     * @throws IOException if both load attempts fail
     */
    private static void smartLoad(String libName) throws IOException {
        addNativeDir(nativePath);
        try {
            System.loadLibrary(libName);
        } catch (UnsatisfiedLinkError e) {
            // not on the library path: extract and load the bundled copy
            // (the original wrapped this in a redundant catch-and-rethrow)
            NativeUtils.loadLibraryFromJar(nativeResourcePath + System.mapLibraryName(libName));
        }
    }

    /**
     * Add libPath to java.library.path so native libraries placed there can
     * be loaded. Works by reflecting into ClassLoader's cached "usr_paths"
     * array, since the system property is read only once at startup.
     * @param libPath directory to add
     * @throws IOException if the reflective update is not permitted
     */
    public static void addNativeDir(String libPath) throws IOException {
        try {
            Field field = ClassLoader.class.getDeclaredField("usr_paths");
            field.setAccessible(true);
            String[] paths = (String[]) field.get(null);
            for (String path : paths) {
                if (libPath.equals(path)) {
                    return;   // already present
                }
            }
            String[] tmp = new String[paths.length + 1];
            System.arraycopy(paths, 0, tmp, 0, paths.length);
            tmp[paths.length] = libPath;
            field.set(null, tmp);
        } catch (IllegalAccessException e) {
            logger.error(e.getMessage());
            throw new IOException("Failed to get permissions to set library path");
        } catch (NoSuchFieldException e) {
            logger.error(e.getMessage());
            throw new IOException("Failed to get field handle to set library path");
        }
    }
}

View File

@ -0,0 +1,99 @@
/*
* To change this license header, choose License Headers in Project Properties.
* To change this template file, choose Tools | Templates
* and open the template in the editor.
*/
package org.dmlc.xgboost4j.util;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
/**
* Simple library class for working with JNI (Java Native Interface)
*
* @see http://adamheinrich.com/2012/how-to-load-native-jni-library-from-jar
*
* @author Adam Heirnich &lt;adam@adamh.cz&gt;, http://www.adamh.cz
*/
public class NativeUtils {
/**
* Private constructor - this class will never be instanced
*/
private NativeUtils() {
}
/**
* Loads library from current JAR archive
*
* The file from JAR is copied into system temporary directory and then loaded. The temporary file is deleted after exiting.
* Method uses String as filename because the pathname is "abstract", not system-dependent.
*
* @param path The filename inside JAR as absolute path (beginning with '/'), e.g. /package/File.ext
* @throws IOException If temporary file creation or read/write operation fails
* @throws IllegalArgumentException If source file (param path) does not exist
* @throws IllegalArgumentException If the path is not absolute or if the filename is shorter than three characters (restriction of {@see File#createTempFile(java.lang.String, java.lang.String)}).
*/
public static void loadLibraryFromJar(String path) throws IOException {
if (!path.startsWith("/")) {
throw new IllegalArgumentException("The path has to be absolute (start with '/').");
}
// Obtain filename from path
String[] parts = path.split("/");
String filename = (parts.length > 1) ? parts[parts.length - 1] : null;
// Split filename to prexif and suffix (extension)
String prefix = "";
String suffix = null;
if (filename != null) {
parts = filename.split("\\.", 2);
prefix = parts[0];
suffix = (parts.length > 1) ? "."+parts[parts.length - 1] : null; // Thanks, davs! :-)
}
// Check if the filename is okay
if (filename == null || prefix.length() < 3) {
throw new IllegalArgumentException("The filename has to be at least 3 characters long.");
}
// Prepare temporary file
File temp = File.createTempFile(prefix, suffix);
temp.deleteOnExit();
if (!temp.exists()) {
throw new FileNotFoundException("File " + temp.getAbsolutePath() + " does not exist.");
}
// Prepare buffer for data copying
byte[] buffer = new byte[1024];
int readBytes;
// Open and check input stream
InputStream is = NativeUtils.class.getResourceAsStream(path);
if (is == null) {
throw new FileNotFoundException("File " + path + " was not found inside JAR.");
}
// Open output stream and copy data between source file in JAR and the temporary file
OutputStream os = new FileOutputStream(temp);
try {
while ((readBytes = is.read(buffer)) != -1) {
os.write(buffer, 0, readBytes);
}
} finally {
// If read/write fails, close streams safely before throwing an exception
os.close();
is.close();
}
// Finally, load the library
System.load(temp.getAbsolutePath());
}
}

View File

@ -0,0 +1,54 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j.util;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map.Entry;
import java.util.AbstractMap;
/**
 * A utility class holding booster parameters as ordered key-value pairs.
 * Duplicate keys are allowed and insertion order is preserved.
 * @author hzx
 */
public class Params implements Iterable<Entry<String, String>> {
    // ordered list of parameter pairs; duplicates are allowed on purpose
    // (e.g. repeated "eval_metric" entries)
    List<Entry<String, String>> params = new ArrayList<>();

    /**
     * put param key-value pair
     * @param key parameter name
     * @param value parameter value
     */
    public void put(String key, String value) {
        params.add(new AbstractMap.SimpleEntry<>(key, value));
    }

    /**
     * One "key:value" line per parameter, in insertion order.
     */
    @Override
    public String toString() {
        // fix: StringBuilder instead of O(n^2) string concatenation in a loop
        StringBuilder sb = new StringBuilder();
        for (Entry<String, String> param : params) {
            sb.append(param.getKey()).append(':').append(param.getValue()).append('\n');
        }
        return sb.toString();
    }

    @Override
    public Iterator<Entry<String, String>> iterator() {
        return params.iterator();
    }
}

View File

@ -0,0 +1,230 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j.util;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.dmlc.xgboost4j.IEvaluation;
import org.dmlc.xgboost4j.Booster;
import org.dmlc.xgboost4j.DMatrix;
import org.dmlc.xgboost4j.IObjective;
/**
 * Training entry points for xgboost: plain training and n-fold cross-validation.
 * @author hzx
 */
public class Trainer {
    private static final Log logger = LogFactory.getLog(Trainer.class);

    /**
     * Train a booster with given parameters.
     * @param params Booster params.
     * @param dtrain Data to be trained.
     * @param round Number of boosting iterations.
     * @param evalMats Data to be evaluated (may include dtrain; may be null or empty)
     * @param evalNames name of data (used for evaluation info)
     * @param obj customized objective (set to null if not used)
     * @param eval customized evaluation (set to null if not used)
     * @return trained booster
     */
    public static Booster train(Params params, DMatrix dtrain, int round,
            DMatrix[] evalMats, String[] evalNames, IObjective obj, IEvaluation eval) {
        // collect all data matrices; the training matrix always comes first
        DMatrix[] allMats;
        if (evalMats != null && evalMats.length > 0) {
            allMats = new DMatrix[evalMats.length + 1];
            allMats[0] = dtrain;
            System.arraycopy(evalMats, 0, allMats, 1, evalMats.length);
        } else {
            allMats = new DMatrix[] {dtrain};
        }
        // initialize booster
        Booster booster = new Booster(params, allMats);
        // prepare native handles used for evaluation.
        // fix: the old code dereferenced evalMats unconditionally (NPE when null)
        // behind an always-true "dataArray==null || names==null" check
        boolean hasEval = evalMats != null && evalMats.length > 0;
        long[] dataArray = null;
        String[] names = null;
        if (hasEval) {
            dataArray = TransferUtil.dMatrixs2handles(evalMats);
            names = evalNames;
        }
        // begin to train
        for (int iter = 0; iter < round; iter++) {
            if (obj != null) {
                booster.update(dtrain, iter, obj);
            } else {
                booster.update(dtrain, iter);
            }
            // evaluation
            if (hasEval) {
                String evalInfo;
                if (eval != null) {
                    evalInfo = booster.evalSet(evalMats, evalNames, iter, eval);
                } else {
                    evalInfo = booster.evalSet(dataArray, names, iter);
                }
                logger.info(evalInfo);
            }
        }
        return booster;
    }

    /**
     * Cross-validation with given parameters.
     * (NOTE: method name keeps its original misspelling for backward compatibility)
     * @param params Booster params.
     * @param data Data to be trained.
     * @param round Number of boosting iterations.
     * @param nfold Number of folds in CV.
     * @param metrics Evaluation metrics to be watched in CV.
     * @param obj customized objective (set to null if not used)
     * @param eval customized evaluation (set to null if not used)
     * @return evaluation history, one aggregated line per round
     */
    public static String[] crossValiation(Params params, DMatrix data, int round, int nfold, String[] metrics, IObjective obj, IEvaluation eval) {
        CVPack[] cvPacks = makeNFold(data, nfold, params, metrics);
        String[] evalHist = new String[round];
        String[] results = new String[cvPacks.length];
        for (int i = 0; i < round; i++) {
            // update each fold's booster for this round
            for (CVPack cvPack : cvPacks) {
                if (obj != null) {
                    cvPack.update(i, obj);
                } else {
                    cvPack.update(i);
                }
            }
            // evaluate each fold
            for (int j = 0; j < cvPacks.length; j++) {
                if (eval != null) {
                    results[j] = cvPacks[j].eval(i, eval);
                } else {
                    results[j] = cvPacks[j].eval(i);
                }
            }
            evalHist[i] = aggCVResults(results);
            logger.info(evalHist[i]);
        }
        return evalHist;
    }

    /**
     * make an n-fold array of CVPack from random indices
     * @param data original data
     * @param nfold num of folds
     * @param params booster parameters
     * @param evalMetrics Evaluation metrics
     * @return CV package array
     */
    public static CVPack[] makeNFold(DMatrix data, int nfold, Params params, String[] evalMetrics) {
        List<Integer> samples = genRandPermutationNums(0, (int) data.rowNum());
        int step = samples.size() / nfold;
        int[] testSlice = new int[step];
        int[] trainSlice = new int[samples.size() - step];
        int testid, trainid;
        CVPack[] cvPacks = new CVPack[nfold];
        for (int i = 0; i < nfold; i++) {
            testid = 0;
            trainid = 0;
            for (int j = 0; j < samples.size(); j++) {
                // fix: use j >= i*step; the old "j > (i*step)" skipped the first
                // index of every test window, scrambling the fold boundaries
                if (j >= i * step && j < i * step + step && testid < step) {
                    testSlice[testid] = samples.get(j);
                    testid++;
                } else if (trainid < samples.size() - step) {
                    trainSlice[trainid] = samples.get(j);
                    trainid++;
                } else {
                    // overflow safety net (kept from original; unreachable when
                    // slice sizes add up exactly)
                    testSlice[testid] = samples.get(j);
                    testid++;
                }
            }
            DMatrix dtrain = data.slice(trainSlice);
            DMatrix dtest = data.slice(testSlice);
            CVPack cvPack = new CVPack(dtrain, dtest, params);
            // set eval types
            if (evalMetrics != null) {
                for (String type : evalMetrics) {
                    cvPack.booster.setParam("eval_metric", type);
                }
            }
            cvPacks[i] = cvPack;
        }
        return cvPacks;
    }

    /**
     * generate a shuffled list of the integers in [start, end)
     */
    private static List<Integer> genRandPermutationNums(int start, int end) {
        List<Integer> samples = new ArrayList<>();
        for (int i = start; i < end; i++) {
            samples.add(i);
        }
        Collections.shuffle(samples);
        return samples;
    }

    /**
     * Aggregate cross-validation results.
     * Input lines are tab-separated "name\tmetric:value..."; the output keeps the
     * first line's name and appends "cv-metric:mean" for every metric.
     * @param results eval info from each data sample
     * @return cross-validation eval info
     */
    public static String aggCVResults(String[] results) {
        Map<String, List<Float>> cvMap = new HashMap<>();
        String aggResult = results[0].split("\t")[0];
        for (String result : results) {
            String[] items = result.split("\t");
            for (int i = 1; i < items.length; i++) {
                String[] tup = items[i].split(":");
                String key = tup[0];
                Float value = Float.valueOf(tup[1]);
                if (!cvMap.containsKey(key)) {
                    cvMap.put(key, new ArrayList<Float>());
                }
                cvMap.get(key).add(value);
            }
        }
        // fix: StringBuilder instead of repeated string concatenation
        StringBuilder sb = new StringBuilder(aggResult);
        for (String key : cvMap.keySet()) {
            float value = 0f;
            for (Float tvalue : cvMap.get(key)) {
                value += tvalue;
            }
            value /= cvMap.get(key).size();
            sb.append(String.format("\tcv-%s:%f", key, value));
        }
        return sb.toString();
    }
}

View File

@ -0,0 +1,55 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j.util;
import org.dmlc.xgboost4j.DMatrix;
/**
 * Conversion helpers shared by the Java wrapper classes.
 * @author hzx
 */
public class TransferUtil {
    /**
     * transfer DMatrix array to handle array (used for native functions)
     * @param dmatrixs matrices to convert
     * @return native handle for each input matrix, in the same order
     */
    public static long[] dMatrixs2handles(DMatrix[] dmatrixs) {
        long[] handles = new long[dmatrixs.length];
        int pos = 0;
        for (DMatrix matrix : dmatrixs) {
            handles[pos++] = matrix.getHandle();
        }
        return handles;
    }

    /**
     * flatten a mat to array (rows concatenated in order)
     * @param mat two-dimensional input; rows may have different lengths
     * @return single array holding every element of mat
     */
    public static float[] flatten(float[][] mat) {
        // first pass: total element count
        int total = 0;
        for (float[] row : mat) {
            total += row.length;
        }
        // second pass: copy each row after the previous one
        float[] flat = new float[total];
        int offset = 0;
        for (float[] row : mat) {
            System.arraycopy(row, 0, flat, offset, row.length);
            offset += row.length;
        }
        return flat;
    }
}

View File

@ -0,0 +1,48 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j.wrapper;
/**
 * xgboost jni wrapper functions for xgboost_wrapper.h.
 * All "handle" longs are raw native pointers returned by the create functions;
 * callers own them and must release them via the matching *Free function.
 * @author hzx
 */
public class XgboostJNI {
    // ----- DMatrix construction -----
    public final static native long XGDMatrixCreateFromFile(String fname, int silent);
    public final static native long XGDMatrixCreateFromCSR(long[] indptr, int[] indices, float[] data);
    public final static native long XGDMatrixCreateFromCSC(long[] colptr, int[] indices, float[] data);
    public final static native long XGDMatrixCreateFromMat(float[] data, int nrow, int ncol, float missing);
    public final static native long XGDMatrixSliceDMatrix(long handle, int[] idxset);
    // ----- DMatrix lifecycle and persistence -----
    public final static native void XGDMatrixFree(long handle);
    public final static native void XGDMatrixSaveBinary(long handle, String fname, int silent);
    // ----- DMatrix meta information -----
    public final static native void XGDMatrixSetFloatInfo(long handle, String field, float[] array);
    public final static native void XGDMatrixSetUIntInfo(long handle, String field, int[] array);
    public final static native void XGDMatrixSetGroup(long handle, int[] group);
    public final static native float[] XGDMatrixGetFloatInfo(long handle, String field);
    // fix: parameter renamed from "filed" (typo); JNI linkage and callers are
    // unaffected by parameter names
    public final static native int[] XGDMatrixGetUIntInfo(long handle, String field);
    public final static native long XGDMatrixNumRow(long handle);
    // ----- Booster -----
    public final static native long XGBoosterCreate(long[] handles);
    public final static native void XGBoosterFree(long handle);
    public final static native void XGBoosterSetParam(long handle, String name, String value);
    public final static native void XGBoosterUpdateOneIter(long handle, int iter, long dtrain);
    public final static native void XGBoosterBoostOneIter(long handle, long dtrain, float[] grad, float[] hess);
    public final static native String XGBoosterEvalOneIter(long handle, int iter, long[] dmats, String[] evnames);
    public final static native float[] XGBoosterPredict(long handle, long dmat, int option_mask, long ntree_limit);
    public final static native void XGBoosterLoadModel(long handle, String fname);
    public final static native void XGBoosterSaveModel(long handle, String fname);
    public final static native void XGBoosterLoadModelFromBuffer(long handle, long buf, long len);
    public final static native String XGBoosterGetModelRaw(long handle);
    public final static native String[] XGBoosterDumpModel(long handle, String fmap, int with_stats);
}

View File

@ -0,0 +1 @@
Please put the native library in this package.

634
java/xgboost4j_wrapper.cpp Normal file
View File

@ -0,0 +1,634 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <jni.h>
#include "../wrapper/xgboost_wrapper.h"
#include "xgboost4j_wrapper.h"
/*
 * Loads a DMatrix from a file; returns the native handle packed into a jlong.
 * NOTE(review): the pointer is stored via type-punning (*(void **)&jresult),
 * which assumes jlong is at least pointer-sized — true on all JVM platforms.
 */
JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromFile
(JNIEnv *jenv, jclass jcls, jstring jfname, jint jsilent) {
jlong jresult = 0 ;
char *fname = (char *) 0 ;
int silent;
void *result = 0 ;
fname = 0;
if (jfname) {
// GetStringUTFChars returns NULL on OOM with a pending Java exception
fname = (char *)jenv->GetStringUTFChars(jfname, 0);
if (!fname) return 0;
}
silent = (int)jsilent;
result = (void *)XGDMatrixCreateFromFile((char const *)fname, silent);
*(void **)&jresult = result;
if (fname) jenv->ReleaseStringUTFChars(jfname, (const char *)fname);
return jresult;
}
/*
 * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
 * Method: XGDMatrixCreateFromCSR
 * Signature: ([J[I[F)J
 */
/*
 * Builds a DMatrix from CSR arrays; nindptr/nelem come from the Java array
 * lengths. NOTE(review): casting jlong* to unsigned long const* assumes
 * sizeof(long) == sizeof(jlong) — true on LP64, not on Windows/LLP64.
 */
JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromCSR
(JNIEnv *jenv, jclass jcls, jlongArray jindptr, jintArray jindices, jfloatArray jdata) {
jlong jresult = 0 ;
bst_ulong nindptr ;
bst_ulong nelem;
void *result = 0 ;
jlong* indptr = jenv->GetLongArrayElements(jindptr, 0);
jint* indices = jenv->GetIntArrayElements(jindices, 0);
jfloat* data = jenv->GetFloatArrayElements(jdata, 0);
nindptr = (bst_ulong)jenv->GetArrayLength(jindptr);
nelem = (bst_ulong)jenv->GetArrayLength(jdata);
result = (void *)XGDMatrixCreateFromCSR((unsigned long const *)indptr, (unsigned int const *)indices, (float const *)data, nindptr, nelem);
*(void **)&jresult = result;
//release pinned Java arrays (mode 0: copy back and free)
jenv->ReleaseLongArrayElements(jindptr, indptr, 0);
jenv->ReleaseIntArrayElements(jindices, indices, 0);
jenv->ReleaseFloatArrayElements(jdata, data, 0);
return jresult;
}
/*
 * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
 * Method: XGDMatrixCreateFromCSC
 * Signature: ([J[I[F)J
 */
/*
 * Builds a DMatrix from CSC (column-major) arrays; mirrors the CSR path above
 * but calls XGDMatrixCreateFromCSC. Same LP64 pointer-size caveat applies.
 */
JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromCSC
(JNIEnv *jenv, jclass jcls, jlongArray jindptr, jintArray jindices, jfloatArray jdata) {
jlong jresult = 0 ;
bst_ulong nindptr ;
bst_ulong nelem;
void *result = 0 ;
jlong* indptr = jenv->GetLongArrayElements(jindptr, NULL);
jint* indices = jenv->GetIntArrayElements(jindices, 0);
jfloat* data = jenv->GetFloatArrayElements(jdata, NULL);
nindptr = (bst_ulong)jenv->GetArrayLength(jindptr);
nelem = (bst_ulong)jenv->GetArrayLength(jdata);
result = (void *)XGDMatrixCreateFromCSC((unsigned long const *)indptr, (unsigned int const *)indices, (float const *)data, nindptr, nelem);
*(void **)&jresult = result;
//release pinned Java arrays
jenv->ReleaseLongArrayElements(jindptr, indptr, 0);
jenv->ReleaseIntArrayElements(jindices, indices, 0);
jenv->ReleaseFloatArrayElements(jdata, data, 0);
return jresult;
}
/*
 * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
 * Method: XGDMatrixCreateFromMat
 * Signature: ([FIIF)J
 */
/*
 * Builds a DMatrix from a dense row-major matrix flattened into jdata
 * (nrow*ncol entries expected); `jmiss` marks missing values.
 */
JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromMat
(JNIEnv *jenv, jclass jcls, jfloatArray jdata, jint jnrow, jint jncol, jfloat jmiss) {
jlong jresult = 0 ;
bst_ulong nrow ;
bst_ulong ncol ;
float miss ;
void *result = 0 ;
jfloat* data = jenv->GetFloatArrayElements(jdata, 0);
nrow = (bst_ulong)jnrow;
ncol = (bst_ulong)jncol;
miss = (float)jmiss;
result = (void *)XGDMatrixCreateFromMat((float const *)data, nrow, ncol, miss);
*(void **)&jresult = result;
//release pinned Java array
jenv->ReleaseFloatArrayElements(jdata, data, 0);
return jresult;
}
/*
 * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
 * Method: XGDMatrixSliceDMatrix
 * Signature: (J[I)J
 */
/*
 * Creates a NEW DMatrix containing only the rows listed in jindexset;
 * the source matrix is left untouched and the caller must free both handles.
 */
JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSliceDMatrix
(JNIEnv *jenv, jclass jcls, jlong jhandle, jintArray jindexset) {
jlong jresult = 0 ;
void *handle = (void *) 0 ;
bst_ulong len;
void *result = 0 ;
jint* indexset = jenv->GetIntArrayElements(jindexset, 0);
handle = *(void **)&jhandle;
len = (bst_ulong)jenv->GetArrayLength(jindexset);
result = (void *)XGDMatrixSliceDMatrix(handle, (int const *)indexset, len);
*(void **)&jresult = result;
//release pinned Java array
jenv->ReleaseIntArrayElements(jindexset, indexset, 0);
return jresult;
}
/*
 * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
 * Method: XGDMatrixFree
 * Signature: (J)V
 */
/*
 * Releases the native DMatrix. The Java-side long becomes a dangling handle
 * afterwards; NOTE(review): behavior for a 0/invalid handle is up to the
 * native library — callers should not double-free.
 */
JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixFree
(JNIEnv *jenv, jclass jcls, jlong jhandle) {
void *handle = (void *) 0 ;
handle = *(void **)&jhandle;
XGDMatrixFree(handle);
}
/*
 * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
 * Method: XGDMatrixSaveBinary
 * Signature: (JLjava/lang/String;I)V
 */
/*
 * Saves the DMatrix to a binary file at `jfname`; `jsilent` suppresses
 * native-side logging when non-zero.
 */
JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSaveBinary
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfname, jint jsilent) {
void *handle = (void *) 0 ;
char *fname = (char *) 0 ;
int silent ;
handle = *(void **)&jhandle;
fname = 0;
if (jfname) {
// NULL means OOM; the pending Java exception propagates after return
fname = (char *)jenv->GetStringUTFChars(jfname, 0);
if (!fname) return ;
}
silent = (int)jsilent;
XGDMatrixSaveBinary(handle, (char const *)fname, silent);
if (fname) jenv->ReleaseStringUTFChars(jfname, (const char *)fname);
}
/*
 * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
 * Method: XGDMatrixSetFloatInfo
 * Signature: (JLjava/lang/String;[F)V
 */
/*
 * Copies a float meta-info vector (e.g. labels/weights — field names are
 * interpreted by the native library) into the DMatrix.
 */
JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetFloatInfo
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jfloatArray jarray) {
void *handle = (void *) 0 ;
char *field = (char *) 0 ;
bst_ulong len;
handle = *(void **)&jhandle;
field = 0;
if (jfield) {
field = (char *)jenv->GetStringUTFChars(jfield, 0);
if (!field) return ;
}
jfloat* array = jenv->GetFloatArrayElements(jarray, NULL);
len = (bst_ulong)jenv->GetArrayLength(jarray);
XGDMatrixSetFloatInfo(handle, (char const *)field, (float const *)array, len);
//release string chars and pinned array
if (field) jenv->ReleaseStringUTFChars(jfield, (const char *)field);
jenv->ReleaseFloatArrayElements(jarray, array, 0);
}
/*
 * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
 * Method: XGDMatrixSetUIntInfo
 * Signature: (JLjava/lang/String;[I)V
 */
/*
 * Copies an unsigned-int meta-info vector into the DMatrix. Java ints are
 * reinterpreted as unsigned; negative Java values map to large unsigned ones.
 */
JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetUIntInfo
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jintArray jarray) {
void *handle = (void *) 0 ;
char *field = (char *) 0 ;
bst_ulong len ;
handle = *(void **)&jhandle;
field = 0;
if (jfield) {
field = (char *)jenv->GetStringUTFChars(jfield, 0);
if (!field) return ;
}
jint* array = jenv->GetIntArrayElements(jarray, NULL);
len = (bst_ulong)jenv->GetArrayLength(jarray);
XGDMatrixSetUIntInfo(handle, (char const *)field, (unsigned int const *)array, len);
//release string chars and pinned array
if (field) jenv->ReleaseStringUTFChars(jfield, (const char *)field);
jenv->ReleaseIntArrayElements(jarray, array, 0);
}
/*
 * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
 * Method: XGDMatrixSetGroup
 * Signature: (J[I)V
 */
/*
 * Sets the query-group sizes on the DMatrix (used for ranking tasks);
 * ints are passed through as unsigned.
 */
JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetGroup
(JNIEnv * jenv, jclass jcls, jlong jhandle, jintArray jarray) {
void *handle = (void *) 0 ;
bst_ulong len ;
handle = *(void **)&jhandle;
jint* array = jenv->GetIntArrayElements(jarray, NULL);
len = (bst_ulong)jenv->GetArrayLength(jarray);
XGDMatrixSetGroup(handle, (unsigned int const *)array, len);
//release pinned array
jenv->ReleaseIntArrayElements(jarray, array, 0);
}
/*
 * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
 * Method: XGDMatrixGetFloatInfo
 * Signature: (JLjava/lang/String;)[F
 */
/*
 * Reads a float meta-info vector from the DMatrix and copies it into a new
 * Java float[]. The native buffer is owned by the library and is only read
 * here; `len` is an out-parameter filled by the native call.
 */
JNIEXPORT jfloatArray JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixGetFloatInfo
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield) {
void *handle = (void *) 0 ;
char *field = (char *) 0 ;
bst_ulong len[1];
*len = 0;
float *result = 0 ;
handle = *(void **)&jhandle;
field = 0;
if (jfield) {
field = (char *)jenv->GetStringUTFChars(jfield, 0);
if (!field) return 0;
}
result = (float *)XGDMatrixGetFloatInfo((void const *)handle, (char const *)field, len);
if (field) jenv->ReleaseStringUTFChars(jfield, (const char *)field);
jsize jlen = (jsize)*len;
jfloatArray jresult = jenv->NewFloatArray(jlen);
jenv->SetFloatArrayRegion(jresult, 0, jlen, (jfloat *)result);
return jresult;
}
/*
 * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
 * Method: XGDMatrixGetUIntInfo
 * Signature: (JLjava/lang/String;)[I
 */
/*
 * Reads an unsigned-int meta-info vector and copies it into a new Java int[].
 * Unsigned values above INT_MAX come back as negative Java ints.
 */
JNIEXPORT jintArray JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixGetUIntInfo
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield) {
void *handle = (void *) 0 ;
char *field = (char *) 0 ;
bst_ulong len[1];
*len = 0;
unsigned int *result = 0 ;
handle = *(void **)&jhandle;
field = 0;
if (jfield) {
field = (char *)jenv->GetStringUTFChars(jfield, 0);
if (!field) return 0;
}
result = (unsigned int *)XGDMatrixGetUIntInfo((void const *)handle, (char const *)field, len);
if (field) jenv->ReleaseStringUTFChars(jfield, (const char *)field);
jsize jlen = (jsize)*len;
jintArray jresult = jenv->NewIntArray(jlen);
jenv->SetIntArrayRegion(jresult, 0, jlen, (jint *)result);
return jresult;
}
/*
 * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
 * Method: XGDMatrixNumRow
 * Signature: (J)J
 */
/*
 * Returns the number of rows in the DMatrix as a jlong.
 */
JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixNumRow
(JNIEnv *jenv, jclass jcls, jlong jhandle) {
jlong jresult = 0 ;
void *handle = (void *) 0 ;
bst_ulong result;
handle = *(void **)&jhandle;
result = (bst_ulong)XGDMatrixNumRow((void const *)handle);
jresult = (jlong)result;
return jresult;
}
/*
 * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
 * Method: XGBoosterCreate
 * Signature: ([J)J
 */
/*
 * Creates a booster over the given cached DMatrix handles (jhandles may be
 * null/empty). The temporary void** array only repacks each jlong into a
 * pointer; it is freed before returning.
 */
JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterCreate
(JNIEnv *jenv, jclass jcls, jlongArray jhandles) {
jlong jresult = 0 ;
void **handles = 0;
bst_ulong len = 0;
void *result = 0 ;
jlong* cjhandles = 0;
if(jhandles) {
len = (bst_ulong)jenv->GetArrayLength(jhandles);
handles = new void*[len];
//put handle from jhandles to chandles
cjhandles = jenv->GetLongArrayElements(jhandles, 0);
for(bst_ulong i=0; i<len; i++) {
handles[i] = *(void **)&cjhandles[i];
}
}
result = (void *)XGBoosterCreate(handles, len);
//release temporary array and pinned Java array
if(jhandles) {
delete[] handles;
jenv->ReleaseLongArrayElements(jhandles, cjhandles, 0);
}
*(void **)&jresult = result;
return jresult;
}
/*
 * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
 * Method: XGBoosterFree
 * Signature: (J)V
 */
/*
 * Releases the native booster; the Java-side handle must not be used again.
 */
JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterFree
(JNIEnv *jenv, jclass jcls, jlong jhandle) {
void *handle = (void *) 0 ;
handle = *(void **)&jhandle;
XGBoosterFree(handle);
}
/*
 * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
 * Method: XGBoosterSetParam
 * Signature: (JLjava/lang/String;Ljava/lang/String;)V
 */
/*
 * Sets a single booster parameter (name, value). Either string may be null,
 * in which case NULL is forwarded to the native call.
 */
JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterSetParam
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jname, jstring jvalue) {
void *handle = (void *) 0 ;
char *name = (char *) 0 ;
char *value = (char *) 0 ;
handle = *(void **)&jhandle;
name = 0;
if (jname) {
name = (char *)jenv->GetStringUTFChars(jname, 0);
if (!name) return ;
}
value = 0;
if (jvalue) {
value = (char *)jenv->GetStringUTFChars(jvalue, 0);
if (!value) {
// fix: release the already-acquired name chars before bailing out
// (previously leaked on this OOM path)
if (name) jenv->ReleaseStringUTFChars(jname, (const char *)name);
return ;
}
}
XGBoosterSetParam(handle, (char const *)name, (char const *)value);
if (name) jenv->ReleaseStringUTFChars(jname, (const char *)name);
if (value) jenv->ReleaseStringUTFChars(jvalue, (const char *)value);
}
/*
 * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
 * Method: XGBoosterUpdateOneIter
 * Signature: (JIJ)V
 */
/*
 * Runs one boosting iteration on dtrain using the booster's configured
 * objective.
 */
JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterUpdateOneIter
(JNIEnv *jenv, jclass jcls, jlong jhandle, jint jiter, jlong jdtrain) {
void *handle = (void *) 0 ;
int iter ;
void *dtrain = (void *) 0 ;
handle = *(void **)&jhandle;
iter = (int)jiter;
dtrain = *(void **)&jdtrain;
XGBoosterUpdateOneIter(handle, iter, dtrain);
}
/*
 * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
 * Method: XGBoosterBoostOneIter
 * Signature: (JJ[F[F)V
 */
/*
 * Runs one boosting iteration with externally supplied gradient/hessian
 * vectors (custom objective path). NOTE(review): len is taken from jgrad
 * only — the caller must pass grad and hess of equal length; a shorter hess
 * would be read out of bounds by the native call.
 */
JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterBoostOneIter
(JNIEnv *jenv, jclass jcls, jlong jhandle, jlong jdtrain, jfloatArray jgrad, jfloatArray jhess) {
void *handle = (void *) 0 ;
void *dtrain = (void *) 0 ;
bst_ulong len ;
handle = *(void **)&jhandle;
dtrain = *(void **)&jdtrain;
jfloat* grad = jenv->GetFloatArrayElements(jgrad, 0);
jfloat* hess = jenv->GetFloatArrayElements(jhess, 0);
len = (bst_ulong)jenv->GetArrayLength(jgrad);
XGBoosterBoostOneIter(handle, dtrain, grad, hess, len);
//release pinned Java arrays
jenv->ReleaseFloatArrayElements(jgrad, grad, 0);
jenv->ReleaseFloatArrayElements(jhess, hess, 0);
}
/*
 * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
 * Method: XGBoosterEvalOneIter
 * Signature: (JI[J[Ljava/lang/String;)Ljava/lang/String;
 */
/*
 * Evaluates the booster on each matrix in jdmats (paired with the name at the
 * same index in jevnames) and returns the native evaluation summary string.
 * jdmats and jevnames must have the same length.
 */
JNIEXPORT jstring JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterEvalOneIter
(JNIEnv *jenv, jclass jcls, jlong jhandle, jint jiter, jlongArray jdmats, jobjectArray jevnames) {
jstring jresult = 0 ;
void *handle = (void *) 0 ;
int iter ;
void **dmats = 0;
char **evnames = 0;
bst_ulong len ;
char *result = 0 ;
handle = *(void **)&jhandle;
iter = (int)jiter;
len = (bst_ulong)jenv->GetArrayLength(jdmats);
if(len > 0) {
dmats = new void*[len];
evnames = new char*[len];
}
//pin the handle array and repack each jlong into a native pointer
jlong* cjdmats = jenv->GetLongArrayElements(jdmats, 0);
for(bst_ulong i=0; i<len; i++) {
dmats[i] = *(void **)&cjdmats[i];
}
//transfer jObjectArray to char**
for(bst_ulong i=0; i<len; i++) {
jstring jevname = (jstring)jenv->GetObjectArrayElement(jevnames, i);
evnames[i] = (char *)jenv->GetStringUTFChars(jevname, 0);
}
result = (char *)XGBoosterEvalOneIter(handle, iter, dmats, (char const *(*))evnames, len);
if(len > 0) {
delete[] dmats;
//release string chars (GetObjectArrayElement returns the same String object)
for(bst_ulong i=0; i<len; i++) {
jstring jevname = (jstring)jenv->GetObjectArrayElement(jevnames, i);
jenv->ReleaseStringUTFChars(jevname, (const char*)evnames[i]);
}
delete[] evnames;
}
// fix: the pinned jdmats elements were acquired unconditionally above but
// previously released only when len > 0, leaking for empty arrays
jenv->ReleaseLongArrayElements(jdmats, cjdmats, 0);
if (result) jresult = jenv->NewStringUTF((const char *)result);
return jresult;
}
/*
 * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
 * Method: XGBoosterPredict
 * Signature: (JJIJ)[F
 */
/*
 * Runs prediction on dmat and copies the native result buffer (owned by the
 * library, length returned via `len`) into a fresh Java float[].
 */
JNIEXPORT jfloatArray JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterPredict
(JNIEnv *jenv, jclass jcls, jlong jhandle, jlong jdmat, jint joption_mask, jlong jntree_limit) {
void *handle = (void *) 0 ;
void *dmat = (void *) 0 ;
int option_mask ;
unsigned int ntree_limit ;
bst_ulong len[1];
*len = 0;
float *result = 0 ;
handle = *(void **)&jhandle;
dmat = *(void **)&jdmat;
option_mask = (int)joption_mask;
ntree_limit = (unsigned int)jntree_limit;
result = (float *)XGBoosterPredict(handle, dmat, option_mask, ntree_limit, len);
jsize jlen = (jsize)*len;
jfloatArray jresult = jenv->NewFloatArray(jlen);
jenv->SetFloatArrayRegion(jresult, 0, jlen, (jfloat *)result);
return jresult;
}
/*
 * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
 * Method: XGBoosterLoadModel
 * Signature: (JLjava/lang/String;)V
 */
/*
 * Loads a previously saved model from the file at jfname into the booster.
 */
JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterLoadModel
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfname) {
void *handle = (void *) 0 ;
char *fname = (char *) 0 ;
handle = *(void **)&jhandle;
fname = 0;
if (jfname) {
fname = (char *)jenv->GetStringUTFChars(jfname, 0);
if (!fname) return ;
}
XGBoosterLoadModel(handle,(char const *)fname);
if (fname) jenv->ReleaseStringUTFChars(jfname, (const char *)fname);
}
/*
 * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
 * Method: XGBoosterSaveModel
 * Signature: (JLjava/lang/String;)V
 */
/*
 * Saves the booster's model to the file at jfname.
 */
JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterSaveModel
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfname) {
void *handle = (void *) 0 ;
char *fname = (char *) 0 ;
handle = *(void **)&jhandle;
fname = 0;
if (jfname) {
fname = (char *)jenv->GetStringUTFChars(jfname, 0);
if (!fname) return ;
}
XGBoosterSaveModel(handle, (char const *)fname);
if (fname) jenv->ReleaseStringUTFChars(jfname, (const char *)fname);
}
/*
 * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
 * Method: XGBoosterLoadModelFromBuffer
 * Signature: (JJJ)V
 */
/*
 * Loads a model from an in-memory buffer of `jlen` bytes. NOTE(review): jbuf
 * is reinterpreted as a raw native pointer — the Java caller must supply a
 * valid native address, not an array reference; verify against the caller.
 */
JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterLoadModelFromBuffer
(JNIEnv *jenv, jclass jcls, jlong jhandle, jlong jbuf, jlong jlen) {
void *handle = (void *) 0 ;
void *buf = (void *) 0 ;
bst_ulong len ;
handle = *(void **)&jhandle;
buf = *(void **)&jbuf;
len = (bst_ulong)jlen;
XGBoosterLoadModelFromBuffer(handle, (void const *)buf, len);
}
/*
 * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
 * Method: XGBoosterGetModelRaw
 * Signature: (J)Ljava/lang/String;
 */
/*
 * Returns the serialized model. NOTE(review): the raw model is binary data,
 * but NewStringUTF treats it as a NUL-terminated modified-UTF-8 string and
 * ignores the `len` the native call filled in — embedded NUL or non-UTF bytes
 * will truncate/corrupt the result. A proper fix needs a byte[]-returning
 * signature, which would change the Java-visible interface; flagged only.
 */
JNIEXPORT jstring JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterGetModelRaw
(JNIEnv * jenv, jclass jcls, jlong jhandle) {
jstring jresult = 0 ;
void *handle = (void *) 0 ;
bst_ulong len[1];
*len = 0;
char *result = 0 ;
handle = *(void **)&jhandle;
result = (char *)XGBoosterGetModelRaw(handle, len);
if (result) jresult = jenv->NewStringUTF((const char *)result);
return jresult;
}
/*
 * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
 * Method: XGBoosterDumpModel
 * Signature: (JLjava/lang/String;I)[Ljava/lang/String;
 */
/*
 * Dumps the model as one text string per tree (optionally with statistics,
 * using the feature map file at jfmap) and returns them as a Java String[].
 */
JNIEXPORT jobjectArray JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterDumpModel
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfmap, jint jwith_stats) {
void *handle = (void *) 0 ;
char *fmap = (char *) 0 ;
int with_stats ;
bst_ulong len[1];
*len = 0;
char **result = 0 ;
handle = *(void **)&jhandle;
fmap = 0;
if (jfmap) {
fmap = (char *)jenv->GetStringUTFChars(jfmap, 0);
if (!fmap) return 0;
}
with_stats = (int)jwith_stats;
result = (char **)XGBoosterDumpModel(handle, (char const *)fmap, with_stats, len);
jsize jlen = (jsize)*len;
jobjectArray jresult = jenv->NewObjectArray(jlen, jenv->FindClass("java/lang/String"), jenv->NewStringUTF(""));
for(int i=0 ; i<jlen; i++) {
jenv->SetObjectArrayElement(jresult, i, jenv->NewStringUTF((const char*)result[i]));
}
if (fmap) jenv->ReleaseStringUTFChars(jfmap, (const char *)fmap);
return jresult;
}

213
java/xgboost4j_wrapper.h Normal file
View File

@ -0,0 +1,213 @@
/* DO NOT EDIT THIS FILE - it is machine generated */
#include <jni.h>
/* Header for class org_dmlc_xgboost4j_wrapper_XgboostJNI */
#ifndef _Included_org_dmlc_xgboost4j_wrapper_XgboostJNI
#define _Included_org_dmlc_xgboost4j_wrapper_XgboostJNI
#ifdef __cplusplus
extern "C" {
#endif
/*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
* Method: XGDMatrixCreateFromFile
* Signature: (Ljava/lang/String;I)J
*/
JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromFile
(JNIEnv *, jclass, jstring, jint);
/*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
* Method: XGDMatrixCreateFromCSR
* Signature: ([J[J[F)J
*/
JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromCSR
(JNIEnv *, jclass, jlongArray, jintArray, jfloatArray);
/*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
* Method: XGDMatrixCreateFromCSC
* Signature: ([J[J[F)J
*/
JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromCSC
(JNIEnv *, jclass, jlongArray, jintArray, jfloatArray);
/*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
* Method: XGDMatrixCreateFromMat
* Signature: ([FIIF)J
*/
JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromMat
(JNIEnv *, jclass, jfloatArray, jint, jint, jfloat);
/*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
* Method: XGDMatrixSliceDMatrix
* Signature: (J[I)J
*/
JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSliceDMatrix
(JNIEnv *, jclass, jlong, jintArray);
/*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
* Method: XGDMatrixFree
* Signature: (J)V
*/
JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixFree
(JNIEnv *, jclass, jlong);
/*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
* Method: XGDMatrixSaveBinary
* Signature: (JLjava/lang/String;I)V
*/
JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSaveBinary
(JNIEnv *, jclass, jlong, jstring, jint);
/*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
* Method: XGDMatrixSetFloatInfo
* Signature: (JLjava/lang/String;[F)V
*/
JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetFloatInfo
(JNIEnv *, jclass, jlong, jstring, jfloatArray);
/*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
* Method: XGDMatrixSetUIntInfo
* Signature: (JLjava/lang/String;[I)V
*/
JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetUIntInfo
(JNIEnv *, jclass, jlong, jstring, jintArray);
/*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
* Method: XGDMatrixSetGroup
* Signature: (J[I)V
*/
JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetGroup
(JNIEnv *, jclass, jlong, jintArray);
/*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
* Method: XGDMatrixGetFloatInfo
* Signature: (JLjava/lang/String;)[F
*/
JNIEXPORT jfloatArray JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixGetFloatInfo
(JNIEnv *, jclass, jlong, jstring);
/*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
* Method: XGDMatrixGetUIntInfo
* Signature: (JLjava/lang/String;)[I
*/
JNIEXPORT jintArray JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixGetUIntInfo
(JNIEnv *, jclass, jlong, jstring);
/*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
* Method: XGDMatrixNumRow
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixNumRow
(JNIEnv *, jclass, jlong);
/*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
* Method: XGBoosterCreate
* Signature: ([J)J
*/
JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterCreate
(JNIEnv *, jclass, jlongArray);
/*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
* Method: XGBoosterFree
* Signature: (J)V
*/
JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterFree
(JNIEnv *, jclass, jlong);
/*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
* Method: XGBoosterSetParam
* Signature: (JLjava/lang/String;Ljava/lang/String;)V
*/
JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterSetParam
(JNIEnv *, jclass, jlong, jstring, jstring);
/*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
* Method: XGBoosterUpdateOneIter
* Signature: (JIJ)V
*/
JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterUpdateOneIter
(JNIEnv *, jclass, jlong, jint, jlong);
/*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
* Method: XGBoosterBoostOneIter
* Signature: (JJ[F[F)V
*/
JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterBoostOneIter
(JNIEnv *, jclass, jlong, jlong, jfloatArray, jfloatArray);
/*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
* Method: XGBoosterEvalOneIter
* Signature: (JI[J[Ljava/lang/String;)Ljava/lang/String;
*/
JNIEXPORT jstring JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterEvalOneIter
(JNIEnv *, jclass, jlong, jint, jlongArray, jobjectArray);
/*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
* Method: XGBoosterPredict
* Signature: (JJIJ)[F
*/
JNIEXPORT jfloatArray JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterPredict
(JNIEnv *, jclass, jlong, jlong, jint, jlong);
/*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
* Method: XGBoosterLoadModel
* Signature: (JLjava/lang/String;)V
*/
JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterLoadModel
(JNIEnv *, jclass, jlong, jstring);
/*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
* Method: XGBoosterSaveModel
* Signature: (JLjava/lang/String;)V
*/
JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterSaveModel
(JNIEnv *, jclass, jlong, jstring);
/*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
* Method: XGBoosterLoadModelFromBuffer
* Signature: (JJJ)V
*/
JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterLoadModelFromBuffer
(JNIEnv *, jclass, jlong, jlong, jlong);
/*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
* Method: XGBoosterGetModelRaw
* Signature: (J)Ljava/lang/String;
*/
JNIEXPORT jstring JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterGetModelRaw
(JNIEnv *, jclass, jlong);
/*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
* Method: XGBoosterDumpModel
* Signature: (JLjava/lang/String;I)[Ljava/lang/String;
*/
JNIEXPORT jobjectArray JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterDumpModel
(JNIEnv *, jclass, jlong, jstring, jint);
#ifdef __cplusplus
}
#endif
#endif

View File

@ -10,6 +10,8 @@ Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "xgboost_wrapper", "xgboost_
EndProject
Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "rabit", "..\subtree\rabit\windows\rabit\rabit.vcxproj", "{D7B77D06-4F5F-4BD7-B81E-7CC8EBBE684F}"
EndProject
Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "xgboostjavawrapper", "xgboostjavawrapper\xgboostjavawrapper.vcxproj", "{20A0E4D7-20C7-4EC1-BDF6-0D469CE239AA}"
EndProject
Global
	GlobalSection(SolutionConfigurationPlatforms) = preSolution
		Debug|Win32 = Debug|Win32
@ -41,6 +43,14 @@ Global
		{D7B77D06-4F5F-4BD7-B81E-7CC8EBBE684F}.Release|Win32.Build.0 = Release|Win32
		{D7B77D06-4F5F-4BD7-B81E-7CC8EBBE684F}.Release|x64.ActiveCfg = Release|x64
		{D7B77D06-4F5F-4BD7-B81E-7CC8EBBE684F}.Release|x64.Build.0 = Release|x64
{20A0E4D7-20C7-4EC1-BDF6-0D469CE239AA}.Debug|Win32.ActiveCfg = Debug|Win32
{20A0E4D7-20C7-4EC1-BDF6-0D469CE239AA}.Debug|Win32.Build.0 = Debug|Win32
{20A0E4D7-20C7-4EC1-BDF6-0D469CE239AA}.Debug|x64.ActiveCfg = Debug|x64
{20A0E4D7-20C7-4EC1-BDF6-0D469CE239AA}.Debug|x64.Build.0 = Debug|x64
{20A0E4D7-20C7-4EC1-BDF6-0D469CE239AA}.Release|Win32.ActiveCfg = Release|Win32
{20A0E4D7-20C7-4EC1-BDF6-0D469CE239AA}.Release|Win32.Build.0 = Release|Win32
{20A0E4D7-20C7-4EC1-BDF6-0D469CE239AA}.Release|x64.ActiveCfg = Release|x64
{20A0E4D7-20C7-4EC1-BDF6-0D469CE239AA}.Release|x64.Build.0 = Release|x64
	EndGlobalSection
	GlobalSection(SolutionProperties) = preSolution
		HideSolutionNode = FALSE

View File

@ -0,0 +1,129 @@
<?xml version="1.0" encoding="utf-8"?>
<!--
  MSBuild project for the xgboost JNI wrapper (libxgboostjavawrapper):
  compiles the JNI glue (java/xgboost4j_wrapper.cpp) together with the
  plain C wrapper and xgboost core sources into a DLL loadable from Java.
  jni.h is resolved via the JAVA_HOME environment variable.
-->
<Project DefaultTargets="Build" ToolsVersion="4.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
  <ItemGroup Label="ProjectConfigurations">
    <ProjectConfiguration Include="Debug|Win32">
      <Configuration>Debug</Configuration>
      <Platform>Win32</Platform>
    </ProjectConfiguration>
    <ProjectConfiguration Include="Debug|x64">
      <Configuration>Debug</Configuration>
      <Platform>x64</Platform>
    </ProjectConfiguration>
    <ProjectConfiguration Include="Release|Win32">
      <Configuration>Release</Configuration>
      <Platform>Win32</Platform>
    </ProjectConfiguration>
    <ProjectConfiguration Include="Release|x64">
      <Configuration>Release</Configuration>
      <Platform>x64</Platform>
    </ProjectConfiguration>
  </ItemGroup>
  <ItemGroup>
    <ClCompile Include="..\..\java\xgboost4j_wrapper.cpp" />
    <ClCompile Include="..\..\src\gbm\gbm.cpp" />
    <ClCompile Include="..\..\src\io\dmlc_simple.cpp" />
    <ClCompile Include="..\..\src\io\io.cpp" />
    <ClCompile Include="..\..\src\tree\updater.cpp" />
    <ClCompile Include="..\..\subtree\rabit\src\engine_empty.cc" />
    <ClCompile Include="..\..\wrapper\xgboost_wrapper.cpp" />
  </ItemGroup>
  <PropertyGroup Label="Globals">
    <ProjectGuid>{20A0E4D7-20C7-4EC1-BDF6-0D469CE239AA}</ProjectGuid>
    <RootNamespace>xgboost_wrapper</RootNamespace>
  </PropertyGroup>
  <Import Project="$(VCTargetsPath)\Microsoft.Cpp.Default.props" />
  <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'" Label="Configuration">
    <ConfigurationType>DynamicLibrary</ConfigurationType>
    <UseDebugLibraries>true</UseDebugLibraries>
    <CharacterSet>MultiByte</CharacterSet>
  </PropertyGroup>
  <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'" Label="Configuration">
    <ConfigurationType>DynamicLibrary</ConfigurationType>
    <UseDebugLibraries>true</UseDebugLibraries>
    <CharacterSet>MultiByte</CharacterSet>
  </PropertyGroup>
  <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|Win32'" Label="Configuration">
    <ConfigurationType>DynamicLibrary</ConfigurationType>
    <UseDebugLibraries>false</UseDebugLibraries>
    <WholeProgramOptimization>true</WholeProgramOptimization>
    <CharacterSet>MultiByte</CharacterSet>
  </PropertyGroup>
  <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'" Label="Configuration">
    <ConfigurationType>DynamicLibrary</ConfigurationType>
    <UseDebugLibraries>false</UseDebugLibraries>
    <WholeProgramOptimization>true</WholeProgramOptimization>
    <CharacterSet>MultiByte</CharacterSet>
  </PropertyGroup>
  <Import Project="$(VCTargetsPath)\Microsoft.Cpp.props" />
  <ImportGroup Label="ExtensionSettings">
  </ImportGroup>
  <ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">
    <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
  </ImportGroup>
  <ImportGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'" Label="PropertySheets">
    <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
  </ImportGroup>
  <ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">
    <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
  </ImportGroup>
  <ImportGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'" Label="PropertySheets">
    <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
  </ImportGroup>
  <PropertyGroup Label="UserMacros" />
  <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
    <OutDir>$(SolutionDir)$(Platform)\$(Configuration)\</OutDir>
  </PropertyGroup>
  <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">
    <ClCompile>
      <WarningLevel>Level3</WarningLevel>
      <Optimization>Disabled</Optimization>
      <!-- Fix: jni.h lives under JAVA_HOME; without these directories the
           Debug build cannot compile the JNI glue (Release already had them). -->
      <AdditionalIncludeDirectories>$(JAVA_HOME)\include;$(JAVA_HOME)\include\win32;%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories>
    </ClCompile>
    <Link>
      <GenerateDebugInformation>true</GenerateDebugInformation>
    </Link>
  </ItemDefinitionGroup>
  <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
    <ClCompile>
      <WarningLevel>Level3</WarningLevel>
      <Optimization>Disabled</Optimization>
      <!-- Fix: same JAVA_HOME include paths as the Release configurations. -->
      <AdditionalIncludeDirectories>$(JAVA_HOME)\include;$(JAVA_HOME)\include\win32;%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories>
    </ClCompile>
    <Link>
      <GenerateDebugInformation>true</GenerateDebugInformation>
    </Link>
  </ItemDefinitionGroup>
  <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">
    <ClCompile>
      <WarningLevel>Level3</WarningLevel>
      <Optimization>MaxSpeed</Optimization>
      <FunctionLevelLinking>true</FunctionLevelLinking>
      <IntrinsicFunctions>true</IntrinsicFunctions>
      <OpenMPSupport>true</OpenMPSupport>
      <AdditionalIncludeDirectories>$(JAVA_HOME)\include;$(JAVA_HOME)\include\win32;%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories>
    </ClCompile>
    <Link>
      <GenerateDebugInformation>true</GenerateDebugInformation>
      <EnableCOMDATFolding>true</EnableCOMDATFolding>
      <OptimizeReferences>true</OptimizeReferences>
      <!-- Consistency fix: Release|x64 already links ws2_32.lib; Win32 needs
           the same winsock dependency. -->
      <AdditionalDependencies>ws2_32.lib;%(AdditionalDependencies)</AdditionalDependencies>
    </Link>
  </ItemDefinitionGroup>
  <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
    <ClCompile>
      <WarningLevel>Level3</WarningLevel>
      <Optimization>MaxSpeed</Optimization>
      <FunctionLevelLinking>true</FunctionLevelLinking>
      <IntrinsicFunctions>true</IntrinsicFunctions>
      <OpenMPSupport>true</OpenMPSupport>
      <RuntimeLibrary>MultiThreaded</RuntimeLibrary>
      <AdditionalIncludeDirectories>$(JAVA_HOME)\include\win32;$(JAVA_HOME)\include;%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories>
    </ClCompile>
    <Link>
      <GenerateDebugInformation>true</GenerateDebugInformation>
      <EnableCOMDATFolding>true</EnableCOMDATFolding>
      <OptimizeReferences>true</OptimizeReferences>
      <AdditionalDependencies>ws2_32.lib;%(AdditionalDependencies)</AdditionalDependencies>
    </Link>
  </ItemDefinitionGroup>
  <Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" />
  <ImportGroup Label="ExtensionTargets">
  </ImportGroup>
</Project>