diff --git a/.gitignore b/.gitignore index c38e16aed..44a215435 100644 --- a/.gitignore +++ b/.gitignore @@ -58,3 +58,11 @@ R-package.Rproj *.cache* R-package/inst R-package/src +#java +java/xgboost4j/target +java/xgboost4j/tmp +java/xgboost4j-demo/target +java/xgboost4j-demo/data/ +java/xgboost4j-demo/tmp/ +java/xgboost4j-demo/model/ +nb-configuration* diff --git a/Makefile b/Makefile index e568222c2..360d55e84 100644 --- a/Makefile +++ b/Makefile @@ -3,6 +3,8 @@ export CXX = g++ export MPICXX = mpicxx export LDFLAGS= -pthread -lm export CFLAGS = -Wall -O3 -msse2 -Wno-unknown-pragmas -funroll-loops +# java include path +export JAVAINCFLAGS = -I${JAVA_HOME}/include -I${JAVA_HOME}/include/linux -I./java ifeq ($(OS), Windows_NT) export CXX = g++ -m64 @@ -53,6 +55,9 @@ else SLIB = wrapper/libxgboostwrapper.so endif +# java lib +JLIB = java/libxgboostjavawrapper.so + # specify tensor path BIN = xgboost MOCKBIN = xgboost.mock @@ -79,6 +84,9 @@ main.o: src/xgboost_main.cpp src/utils/*.h src/*.h src/learner/*.hpp src/learner xgboost: updater.o gbm.o io.o main.o $(LIBRABIT) $(LIBDMLC) wrapper/xgboost_wrapper.dll wrapper/libxgboostwrapper.so: wrapper/xgboost_wrapper.cpp src/utils/*.h src/*.h src/learner/*.hpp src/learner/*.h updater.o gbm.o io.o $(LIBRABIT) $(LIBDMLC) +java: java/libxgboostjavawrapper.so +java/libxgboostjavawrapper.so: java/xgboost4j_wrapper.cpp wrapper/xgboost_wrapper.cpp src/utils/*.h src/*.h src/learner/*.hpp src/learner/*.h updater.o gbm.o io.o $(LIBRABIT) $(LIBDMLC) + # dependency on rabit subtree/rabit/lib/librabit.a: subtree/rabit/src/engine.cc + cd subtree/rabit;make lib/librabit.a; cd ../.. @@ -98,6 +106,9 @@ $(MOCKBIN) : $(SLIB) : $(CXX) $(CFLAGS) -fPIC -shared -o $@ $(filter %.cpp %.o %.c %.a %.cc, $^) $(LDFLAGS) $(DLLFLAGS) +$(JLIB) : + $(CXX) $(CFLAGS) -fPIC -shared -o $@ $(filter %.cpp %.o %.c %.a %.cc, $^) $(LDFLAGS) $(JAVAINCFLAGS) + $(OBJ) : $(CXX) -c $(CFLAGS) -o $@ $(firstword $(filter %.cpp %.c %.cc, $^) ) @@ -105,7 +116,7 @@ $(MPIOBJ) : $(MPICXX) -c $(CFLAGS) -o $@ $(firstword $(filter %.cpp %.c, $^) ) $(MPIBIN) : - $(MPICXX) $(CFLAGS) -o $@ $(filter %.cpp %.o %.c %.cc %.a, $^) $(LDFLAGS) + $(MPICXX) $(CFLAGS) -o $@ $(filter %.cpp %.o %.c %.cc %.a, $^) $(LDFLAGS) install: cp -f -r $(BIN) $(INSTALL_PATH) diff --git a/java/README.md b/java/README.md new file mode 100644 index 000000000..12cbb4582 --- /dev/null +++ b/java/README.md @@ -0,0 +1,28 @@ +# xgboost4j +this is a java wrapper for xgboost + +the structure of this wrapper is almost the same as the official python wrapper. + +core of this wrapper is two classes: + +* DMatrix: for handling data + +* Booster: for train and predict + +## usage: + please refer to [xgboost4j.md](doc/xgboost4j.md) for more information. + + besides, simple examples could be found in [xgboost4j-demo](xgboost4j-demo/README.md) + + +## build native library + +for windows: open the xgboost.sln in "../windows" folder, you will found the xgboostjavawrapper project, you should do the following steps to build wrapper library: + * Select x64/win32 and Release in build + * (if you have setted `JAVA_HOME` properly in windows environment variables, escape this step) right click on xgboostjavawrapper project -> choose "Properties" -> click on "C/C++" in the window -> change the "Additional Include Directories" to fit your jdk install path. 
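+   (the directories to add are normally `<your jdk path>\include` and `<your jdk path>\include\win32`, which contain `jni.h` and `jni_md.h`; this mirrors the `JAVAINCFLAGS` include path the Makefile uses for the Linux build)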
+ * rebuild all + * double click "create_wrap.bat" to set library to proper place + +for linux: + * make sure you have installed jdk and `JAVA_HOME` has been setted properly + * run "create_wrap.sh" diff --git a/java/create_wrap.bat b/java/create_wrap.bat new file mode 100644 index 000000000..e7f8603cd --- /dev/null +++ b/java/create_wrap.bat @@ -0,0 +1,20 @@ +echo "move native library" +set libsource=..\windows\x64\Release\xgboostjavawrapper.dll + +if not exist %libsource% ( +goto end +) + +set libfolder=xgboost4j\src\main\resources\lib +set libpath=%libfolder%\xgboostjavawrapper.dll +if not exist %libfolder% (mkdir %libfolder%) +if exist %libpath% (del %libpath%) +move %libsource% %libfolder% +echo complete +pause +exit + +:end + echo "source library not found, please build it first from ..\windows\xgboost.sln" + pause + exit \ No newline at end of file diff --git a/java/create_wrap.sh b/java/create_wrap.sh new file mode 100755 index 000000000..d66e4dbd4 --- /dev/null +++ b/java/create_wrap.sh @@ -0,0 +1,15 @@ +echo "build java wrapper" +cd .. +make java +cd java +echo "move native lib" + +libPath="xgboost4j/src/main/resources/lib" +if [ ! -d "$libPath" ]; then + mkdir -p "$libPath" +fi + +rm -f xgboost4j/src/main/resources/lib/libxgboostjavawrapper.so +mv libxgboostjavawrapper.so xgboost4j/src/main/resources/lib/ + +echo "complete" diff --git a/java/doc/xgboost4j.md b/java/doc/xgboost4j.md new file mode 100644 index 000000000..201b3cc05 --- /dev/null +++ b/java/doc/xgboost4j.md @@ -0,0 +1,156 @@ +xgboost4j : java wrapper for xgboost +==== + +This page will introduce xgboost4j, the java wrapper for xgboost, including: +* [Building](#build-xgboost4j) +* [Data Interface](#data-interface) +* [Setting Parameters](#setting-parameters) +* [Train Model](#training-model) +* [Prediction](#prediction) + += +#### Build xgboost4j +* Build native library +first make sure you have installed jdk and `JAVA_HOME` has been setted properly, then simply run `./create_wrap.sh`. + +* Package xgboost4j +to package xgboost4j, you can run `mvn package` in xgboost4j folder or just use IDE(eclipse/netbeans) to open this maven project and build. + += +#### Data Interface +Like the xgboost python module, xgboost4j use ```DMatrix``` to handle data, libsvm txt format file, sparse matrix in CSR/CSC format, and dense matrix is supported. 
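+
+As a quick orientation before the individual constructors described below, a minimal sketch that loads the demo data file shipped with this repository and checks the size of the resulting matrix (```rowNum()``` returns the number of rows of a loaded ```DMatrix```):
+```java
+import org.dmlc.xgboost4j.DMatrix;
+
+//load a libsvm text file and report how many rows it contains
+DMatrix dtrain = new DMatrix("../../demo/data/agaricus.txt.train");
+System.out.println("rows: " + dtrain.rowNum());
+```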
+ +* To import ```DMatrix``` : +```java +import org.dmlc.xgboost4j.DMatrix; +``` + +* To load libsvm text format file, the usage is like : +```java +DMatrix dmat = new DMatrix("train.svm.txt"); +``` + +* To load sparse matrix in CSR/CSC format is a little complicated, the usage is like : +suppose a sparse matrix : +1 0 2 0 +4 0 0 3 +3 1 2 0 + + for CSR format +```java +long[] rowHeaders = new long[] {0,2,4,7}; +float[] data = new float[] {1f,2f,4f,3f,3f,1f,2f}; +int[] colIndex = new int[] {0,2,0,3,0,1,2}; +DMatrix dmat = new DMatrix(rowHeaders, colIndex, data, DMatrix.SparseType.CSR); +``` + + for CSC format +```java +long[] colHeaders = new long[] {0,3,4,6,7}; +float[] data = new float[] {1f,4f,3f,1f,2f,2f,3f}; +int[] rowIndex = new int[] {0,1,2,2,0,2,1}; +DMatrix dmat = new DMatrix(colHeaders, rowIndex, data, DMatrix.SparseType.CSC); +``` + +* To load 3*2 dense matrix, the usage is like : +suppose a matrix : +1 2 +3 4 +5 6 + +```java +float[] data = new float[] {1f,2f,3f,4f,5f,6f}; +int nrow = 3; +int ncol = 2; +float missing = 0.0f; +DMatrix dmat = new Matrix(data, nrow, ncol, missing); +``` + +* To set weight : +```java +float[] weights = new float[] {1f,2f,1f}; +dmat.setWeight(weights); +``` + +#### Setting Parameters +* in xgboost4j any ```Iterable>``` object could be used as parameters. + +* to set parameters, for non-multiple value params, you can simply use entrySet of an Map: +```java +Map paramMap = new HashMap<>() { + { + put("eta", 1.0); + put("max_depth", 2); + put("silent", 1); + put("objective", "binary:logistic"); + put("eval_metric", "logloss"); + } +}; +Iterable> params = paramMap.entrySet(); +``` +* for the situation that multiple values with same param key, List> would be a good choice, e.g. : +```java +List> params = new ArrayList>() { + { + add(new SimpleEntry("eta", 1.0)); + add(new SimpleEntry("max_depth", 2.0)); + add(new SimpleEntry("silent", 1)); + add(new SimpleEntry("objective", "binary:logistic")); + } +}; +``` + +#### Training Model +With parameters and data, you are able to train a booster model. +* Import ```Trainer``` and ```Booster``` : +```java +import org.dmlc.xgboost4j.Booster; +import org.dmlc.xgboost4j.util.Trainer; +``` + +* Training +```java +DMatrix trainMat = new DMatrix("train.svm.txt"); +DMatrix validMat = new DMatrix("valid.svm.txt"); +//specifiy a watchList to see the performance +//any Iterable> object could be used as watchList +List> watchs = new ArrayList<>(); +watchs.add(new SimpleEntry<>("train", trainMat)); +watchs.add(new SimpleEntry<>("test", testMat)); +int round = 2; +Booster booster = Trainer.train(params, trainMat, round, watchs, null, null); +``` + +* Saving model +After training, you can save model and dump it out. 
+```java +booster.saveModel("model.bin"); +``` + +* Dump Model and Feature Map +```java +booster.dumpModel("modelInfo.txt", false) +//dump with featureMap +booster.dumpModel("modelInfo.txt", "featureMap.txt", false) +``` + +* Load a model +```java +Params param = new Params() { + { + put("silent", 1); + put("nthread", 6); + } +}; +Booster booster = new Booster(param, "model.bin"); +``` + +####Prediction +after training and loading a model, you use it to predict other data, the predict results will be a two-dimension float array (nsample, nclass) ,for predict leaf, it would be (nsample, nclass*ntrees) +```java +DMatrix dtest = new DMatrix("test.svm.txt"); +//predict +float[][] predicts = booster.predict(dtest); +//predict leaf +float[][] leafPredicts = booster.predict(dtest, 0, true); +``` diff --git a/java/xgboost4j-demo/LICENSE b/java/xgboost4j-demo/LICENSE new file mode 100644 index 000000000..9a1673be2 --- /dev/null +++ b/java/xgboost4j-demo/LICENSE @@ -0,0 +1,15 @@ +/* +Copyright (c) 2014 by Contributors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ \ No newline at end of file diff --git a/java/xgboost4j-demo/README.md b/java/xgboost4j-demo/README.md new file mode 100644 index 000000000..c9cb35e4b --- /dev/null +++ b/java/xgboost4j-demo/README.md @@ -0,0 +1,10 @@ +xgboost4j examples +==== +* [Basic walkthrough of wrappers](src/main/java/org/dmlc/xgboost4j/demo/BasicWalkThrough.java) +* [Cutomize loss function, and evaluation metric](src/main/java/org/dmlc/xgboost4j/demo/CustomObjective.java) +* [Boosting from existing prediction](src/main/java/org/dmlc/xgboost4j/demo/BoostFromPrediction.java) +* [Predicting using first n trees](src/main/java/org/dmlc/xgboost4j/demo/PredictFirstNtree.java) +* [Generalized Linear Model](src/main/java/org/dmlc/xgboost4j/demo/GeneralizedLinearModel.java) +* [Cross validation](src/main/java/org/dmlc/xgboost4j/demo/CrossValidation.java) +* [Predicting leaf indices](src/main/java/org/dmlc/xgboost4j/demo/PredictLeafIndices.java) +* [External Memory](src/main/java/org/dmlc/xgboost4j/demo/ExternalMemory.java) diff --git a/java/xgboost4j-demo/pom.xml b/java/xgboost4j-demo/pom.xml new file mode 100644 index 000000000..28c51bc13 --- /dev/null +++ b/java/xgboost4j-demo/pom.xml @@ -0,0 +1,36 @@ + + + 4.0.0 + org.dmlc + xgboost4j-demo + 1.0 + jar + + UTF-8 + 1.7 + 1.7 + + + + org.dmlc + xgboost4j + 1.1 + + + commons-io + commons-io + 2.4 + + + org.apache.commons + commons-lang3 + 3.4 + + + junit + junit + 4.11 + test + + + \ No newline at end of file diff --git a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/BasicWalkThrough.java b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/BasicWalkThrough.java new file mode 100644 index 000000000..a0c7a3ae1 --- /dev/null +++ b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/BasicWalkThrough.java @@ -0,0 +1,163 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ +package org.dmlc.xgboost4j.demo; + +import java.io.File; +import java.io.IOException; +import java.io.UnsupportedEncodingException; +import java.util.AbstractMap; +import java.util.AbstractMap.SimpleEntry; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import org.dmlc.xgboost4j.Booster; +import org.dmlc.xgboost4j.DMatrix; +import org.dmlc.xgboost4j.demo.util.DataLoader; +import org.dmlc.xgboost4j.demo.util.Params; +import org.dmlc.xgboost4j.util.Trainer; + +/** + * a simple example of java wrapper for xgboost + * @author hzx + */ +public class BasicWalkThrough { + public static boolean checkPredicts(float[][] fPredicts, float[][] sPredicts) { + if(fPredicts.length != sPredicts.length) { + return false; + } + + for(int i=0; i> object would be used as paramters + //e.g. + // Map paramMap = new HashMap() { + // { + // put("eta", 1.0); + // put("max_depth", 2); + // put("silent", 1); + // put("objective", "binary:logistic"); + // } + // }; + // Iterable> param = paramMap.entrySet(); + + //or + // List> param = new ArrayList>() { + // { + // add(new SimpleEntry("eta", 1.0)); + // add(new SimpleEntry("max_depth", 2.0)); + // add(new SimpleEntry("silent", 1)); + // add(new SimpleEntry("objective", "binary:logistic")); + // } + // }; + + //we use a util class Params to handle parameters as example + Iterable> param = new Params() { + { + put("eta", 1.0); + put("max_depth", 2); + put("silent", 1); + put("objective", "binary:logistic"); + } + }; + + + + //specify watchList to set evaluation dmats + //note: any Iterable> object would be used as watchList + //e.g. 
+ //an entrySet of Map is good + // Map watchMap = new HashMap<>(); + // watchMap.put("train", trainMat); + // watchMap.put("test", testMat); + // Iterable> watchs = watchMap.entrySet(); + + //we use a List of Entry WatchList as example + List> watchs = new ArrayList<>(); + watchs.add(new SimpleEntry<>("train", trainMat)); + watchs.add(new SimpleEntry<>("test", testMat)); + + //set round + int round = 2; + + //train a boost model + Booster booster = Trainer.train(param, trainMat, round, watchs, null, null); + + //predict + float[][] predicts = booster.predict(testMat); + + //save model to modelPath + File file = new File("./model"); + if(!file.exists()) { + file.mkdirs(); + } + + String modelPath = "./model/xgb.model"; + booster.saveModel(modelPath); + + //dump model + booster.dumpModel("./model/dump.raw.txt", false); + + //dump model with feature map + booster.dumpModel("./model/dump.nice.txt", "../../demo/data/featmap.txt", false); + + //save dmatrix into binary buffer + testMat.saveBinary("./model/dtest.buffer"); + + //reload model and data + Booster booster2 = new Booster(param, "./model/xgb.model"); + DMatrix testMat2 = new DMatrix("./model/dtest.buffer"); + float[][] predicts2 = booster2.predict(testMat2); + + + //check the two predicts + System.out.println(checkPredicts(predicts, predicts2)); + + System.out.println("start build dmatrix from csr sparse data ..."); + //build dmatrix from CSR Sparse Matrix + DataLoader.CSRSparseData spData = DataLoader.loadSVMFile("../../demo/data/agaricus.txt.train"); + + DMatrix trainMat2 = new DMatrix(spData.rowHeaders, spData.colIndex, spData.data, DMatrix.SparseType.CSR); + trainMat2.setLabel(spData.labels); + + //specify watchList + List> watchs2 = new ArrayList<>(); + watchs2.add(new SimpleEntry<>("train", trainMat2)); + watchs2.add(new SimpleEntry<>("test", testMat2)); + Booster booster3 = Trainer.train(param, trainMat2, round, watchs2, null, null); + float[][] predicts3 = booster3.predict(testMat2); + + //check predicts + System.out.println(checkPredicts(predicts, predicts3)); + } +} diff --git a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/BoostFromPrediction.java b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/BoostFromPrediction.java new file mode 100644 index 000000000..733c49503 --- /dev/null +++ b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/BoostFromPrediction.java @@ -0,0 +1,66 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ */ +package org.dmlc.xgboost4j.demo; + +import java.util.AbstractMap; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import org.dmlc.xgboost4j.Booster; +import org.dmlc.xgboost4j.DMatrix; +import org.dmlc.xgboost4j.demo.util.Params; +import org.dmlc.xgboost4j.util.Trainer; + +/** + * example for start from a initial base prediction + * @author hzx + */ +public class BoostFromPrediction { + public static void main(String[] args) { + System.out.println("start running example to start from a initial prediction"); + + // load file from text file, also binary buffer generated by xgboost4j + DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); + DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test"); + + //specify parameters + Params param = new Params() { + { + put("eta", 1.0); + put("max_depth", 2); + put("silent", 1); + put("objective", "binary:logistic"); + } + }; + + //specify watchList + List> watchs = new ArrayList<>(); + watchs.add(new AbstractMap.SimpleEntry<>("train", trainMat)); + watchs.add(new AbstractMap.SimpleEntry<>("test", testMat)); + + //train xgboost for 1 round + Booster booster = Trainer.train(param, trainMat, 1, watchs, null, null); + + float[][] trainPred = booster.predict(trainMat, true); + float[][] testPred = booster.predict(testMat, true); + + trainMat.setBaseMargin(trainPred); + testMat.setBaseMargin(testPred); + + System.out.println("result of running from initial prediction"); + Booster booster2 = Trainer.train(param, trainMat, 1, watchs, null, null); + } +} diff --git a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/CrossValidation.java b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/CrossValidation.java new file mode 100644 index 000000000..0c470bf17 --- /dev/null +++ b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/CrossValidation.java @@ -0,0 +1,53 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ */ +package org.dmlc.xgboost4j.demo; + +import java.io.IOException; +import org.dmlc.xgboost4j.DMatrix; +import org.dmlc.xgboost4j.util.Trainer; +import org.dmlc.xgboost4j.demo.util.Params; + +/** + * an example of cross validation + * @author hzx + */ +public class CrossValidation { + public static void main(String[] args) throws IOException { + //load train mat + DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); + + //set params + Params param = new Params() { + { + put("eta", 1.0); + put("max_depth", 3); + put("silent", 1); + put("nthread", 6); + put("objective", "binary:logistic"); + put("gamma", 1.0); + put("eval_metric", "error"); + } + }; + + //do 5-fold cross validation + int round = 2; + int nfold = 5; + //set additional eval_metrics + String[] metrics = null; + + String[] evalHist = Trainer.crossValiation(param, trainMat, round, nfold, metrics, null, null); + } +} diff --git a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/CustomObjective.java b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/CustomObjective.java new file mode 100644 index 000000000..03c9c4b52 --- /dev/null +++ b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/CustomObjective.java @@ -0,0 +1,156 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ */ +package org.dmlc.xgboost4j.demo; + +import java.util.AbstractMap; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import org.dmlc.xgboost4j.Booster; +import org.dmlc.xgboost4j.IEvaluation; +import org.dmlc.xgboost4j.DMatrix; +import org.dmlc.xgboost4j.IObjective; +import org.dmlc.xgboost4j.demo.util.Params; +import org.dmlc.xgboost4j.util.Trainer; + +/** + * an example user define objective and eval + * NOTE: when you do customized loss function, the default prediction value is margin + * this may make buildin evalution metric not function properly + * for example, we are doing logistic loss, the prediction is score before logistic transformation + * he buildin evaluation error assumes input is after logistic transformation + * Take this in mind when you use the customization, and maybe you need write customized evaluation function + * @author hzx + */ +public class CustomObjective { + /** + * loglikelihoode loss obj function + */ + public static class LogRegObj implements IObjective { + /** + * simple sigmoid func + * @param input + * @return + * Note: this func is not concern about numerical stability, only used as example + */ + public float sigmoid(float input) { + float val = (float) (1/(1+Math.exp(-input))); + return val; + } + + public float[][] transform(float[][] predicts) { + int nrow = predicts.length; + float[][] transPredicts = new float[nrow][1]; + + for(int i=0; i getGradient(float[][] predicts, DMatrix dtrain) { + int nrow = predicts.length; + List gradients = new ArrayList<>(); + float[] labels = dtrain.getLabel(); + float[] grad = new float[nrow]; + float[] hess = new float[nrow]; + + float[][] transPredicts = transform(predicts); + + for(int i=0; i0) { + error++; + } + else if(labels[i]==1f && predicts[i][0]<=0) { + error++; + } + } + + return error/labels.length; + } + } + + public static void main(String[] args) { + //load train mat (svmlight format) + DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); + //load valid mat (svmlight format) + DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test"); + + //set params + //set params + Params param = new Params() { + { + put("eta", 1.0); + put("max_depth", 2); + put("silent", 1); + } + }; + + //set round + int round = 2; + + //specify watchList + List> watchs = new ArrayList<>(); + watchs.add(new AbstractMap.SimpleEntry<>("train", trainMat)); + watchs.add(new AbstractMap.SimpleEntry<>("test", testMat)); + + //user define obj and eval + IObjective obj = new LogRegObj(); + IEvaluation eval = new EvalError(); + + //train a booster + System.out.println("begin to train the booster model"); + Booster booster = Trainer.train(param, trainMat, round, watchs, obj, eval); + } +} diff --git a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/ExternalMemory.java b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/ExternalMemory.java new file mode 100644 index 000000000..6ac687289 --- /dev/null +++ b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/ExternalMemory.java @@ -0,0 +1,64 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ +package org.dmlc.xgboost4j.demo; + +import java.util.AbstractMap; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import org.dmlc.xgboost4j.Booster; +import org.dmlc.xgboost4j.DMatrix; +import org.dmlc.xgboost4j.demo.util.Params; +import org.dmlc.xgboost4j.util.Trainer; + +/** + * simple example for using external memory version + * @author hzx + */ +public class ExternalMemory { + public static void main(String[] args) { + //this is the only difference, add a # followed by a cache prefix name + //several cache file with the prefix will be generated + //currently only support convert from libsvm file + DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train#dtrain.cache"); + DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test#dtest.cache"); + + //specify parameters + Params param = new Params() { + { + put("eta", 1.0); + put("max_depth", 2); + put("silent", 1); + put("objective", "binary:logistic"); + } + }; + + //performance notice: set nthread to be the number of your real cpu + //some cpu offer two threads per core, for example, a 4 core cpu with 8 threads, in such case set nthread=4 + //param.put("nthread", num_real_cpu); + + //specify watchList + List> watchs = new ArrayList<>(); + watchs.add(new AbstractMap.SimpleEntry<>("train", trainMat)); + watchs.add(new AbstractMap.SimpleEntry<>("test", testMat)); + + //set round + int round = 2; + + //train a boost model + Booster booster = Trainer.train(param, trainMat, round, watchs, null, null); + } +} diff --git a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/GeneralizedLinearModel.java b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/GeneralizedLinearModel.java new file mode 100644 index 000000000..2a20edbff --- /dev/null +++ b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/GeneralizedLinearModel.java @@ -0,0 +1,73 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ */ +package org.dmlc.xgboost4j.demo; + +import java.util.AbstractMap; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import org.dmlc.xgboost4j.Booster; +import org.dmlc.xgboost4j.DMatrix; +import org.dmlc.xgboost4j.demo.util.CustomEval; +import org.dmlc.xgboost4j.demo.util.Params; +import org.dmlc.xgboost4j.util.Trainer; + +/** + * this is an example of fit generalized linear model in xgboost + * basically, we are using linear model, instead of tree for our boosters + * @author hzx + */ +public class GeneralizedLinearModel { + public static void main(String[] args) { + // load file from text file, also binary buffer generated by xgboost4j + DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); + DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test"); + + //specify parameters + //change booster to gblinear, so that we are fitting a linear model + // alpha is the L1 regularizer + //lambda is the L2 regularizer + //you can also set lambda_bias which is L2 regularizer on the bias term + Params param = new Params() { + { + put("alpha", 0.0001); + put("silent", 1); + put("objective", "binary:logistic"); + put("booster", "gblinear"); + } + }; + //normally, you do not need to set eta (step_size) + //XGBoost uses a parallel coordinate descent algorithm (shotgun), + //there could be affection on convergence with parallelization on certain cases + //setting eta to be smaller value, e.g 0.5 can make the optimization more stable + //param.put("eta", "0.5"); + + + //specify watchList + List> watchs = new ArrayList<>(); + watchs.add(new AbstractMap.SimpleEntry<>("train", trainMat)); + watchs.add(new AbstractMap.SimpleEntry<>("test", testMat)); + + //train a booster + int round = 4; + Booster booster = Trainer.train(param, trainMat, round, watchs, null, null); + + float[][] predicts = booster.predict(testMat); + + CustomEval eval = new CustomEval(); + System.out.println("error=" + eval.eval(predicts, testMat)); + } +} diff --git a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/PredictFirstNtree.java b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/PredictFirstNtree.java new file mode 100644 index 000000000..8e3f3abfb --- /dev/null +++ b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/PredictFirstNtree.java @@ -0,0 +1,68 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ */ +package org.dmlc.xgboost4j.demo; + +import java.util.AbstractMap; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import org.dmlc.xgboost4j.Booster; +import org.dmlc.xgboost4j.DMatrix; +import org.dmlc.xgboost4j.util.Trainer; + +import org.dmlc.xgboost4j.demo.util.CustomEval; +import org.dmlc.xgboost4j.demo.util.Params; + +/** + * predict first ntree + * @author hzx + */ +public class PredictFirstNtree { + public static void main(String[] args) { + // load file from text file, also binary buffer generated by xgboost4j + DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); + DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test"); + + //specify parameters + Params param = new Params() { + { + put("eta", 1.0); + put("max_depth", 2); + put("silent", 1); + put("objective", "binary:logistic"); + } + }; + + //specify watchList + List> watchs = new ArrayList<>(); + watchs.add(new AbstractMap.SimpleEntry<>("train", trainMat)); + watchs.add(new AbstractMap.SimpleEntry<>("test", testMat)); + + //train a booster + int round = 3; + Booster booster = Trainer.train(param, trainMat, round, watchs, null, null); + + //predict use 1 tree + float[][] predicts1 = booster.predict(testMat, false, 1); + //by default all trees are used to do predict + float[][] predicts2 = booster.predict(testMat); + + //use a simple evaluation class to check error result + CustomEval eval = new CustomEval(); + System.out.println("error of predicts1: " + eval.eval(predicts1, testMat)); + System.out.println("error of predicts2: " + eval.eval(predicts2, testMat)); + } +} diff --git a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/PredictLeafIndices.java b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/PredictLeafIndices.java new file mode 100644 index 000000000..697f40379 --- /dev/null +++ b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/PredictLeafIndices.java @@ -0,0 +1,69 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ */ +package org.dmlc.xgboost4j.demo; + +import java.util.AbstractMap; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import org.dmlc.xgboost4j.Booster; +import org.dmlc.xgboost4j.DMatrix; +import org.dmlc.xgboost4j.util.Trainer; +import org.dmlc.xgboost4j.demo.util.Params; + +/** + * predict leaf indices + * @author hzx + */ +public class PredictLeafIndices { + public static void main(String[] args) { + // load file from text file, also binary buffer generated by xgboost4j + DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); + DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test"); + + //specify parameters + Params param = new Params() { + { + put("eta", 1.0); + put("max_depth", 2); + put("silent", 1); + put("objective", "binary:logistic"); + } + }; + + //specify watchList + List> watchs = new ArrayList<>(); + watchs.add(new AbstractMap.SimpleEntry<>("train", trainMat)); + watchs.add(new AbstractMap.SimpleEntry<>("test", testMat)); + + //train a booster + int round = 3; + Booster booster = Trainer.train(param, trainMat, round, watchs, null, null); + + //predict using first 2 tree + float[][] leafindex = booster.predict(testMat, 2, true); + for(float[] leafs : leafindex) { + System.out.println(Arrays.toString(leafs)); + } + + //predict all trees + leafindex = booster.predict(testMat, 0, true); + for(float[] leafs : leafindex) { + System.out.println(Arrays.toString(leafs)); + } + } +} diff --git a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/util/CustomEval.java b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/util/CustomEval.java new file mode 100644 index 000000000..ad3a9124b --- /dev/null +++ b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/util/CustomEval.java @@ -0,0 +1,50 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ */ +package org.dmlc.xgboost4j.demo.util; + +import org.dmlc.xgboost4j.DMatrix; +import org.dmlc.xgboost4j.IEvaluation; + +/** + * a util evaluation class for examples + * @author hzx + */ +public class CustomEval implements IEvaluation { + + String evalMetric = "custom_error"; + + @Override + public String getMetric() { + return evalMetric; + } + + @Override + public float eval(float[][] predicts, DMatrix dmat) { + float error = 0f; + float[] labels = dmat.getLabel(); + int nrow = predicts.length; + for(int i=0; i0.5) { + error++; + } + else if(labels[i]==1f && predicts[i][0]<=0.5) { + error++; + } + } + + return error/labels.length; + } +} diff --git a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/util/DataLoader.java b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/util/DataLoader.java new file mode 100644 index 000000000..0a020c761 --- /dev/null +++ b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/util/DataLoader.java @@ -0,0 +1,129 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ +package org.dmlc.xgboost4j.demo.util; + +import java.io.BufferedReader; +import java.io.File; +import java.io.FileInputStream; +import java.io.FileNotFoundException; +import java.io.IOException; +import java.io.InputStreamReader; +import java.io.UnsupportedEncodingException; +import java.util.ArrayList; +import java.util.List; +import org.apache.commons.lang3.ArrayUtils; + +/** + * util class for loading data + * @author hzx + */ +public class DataLoader { + public static class DenseData { + public float[] labels; + public float[] data; + public int nrow; + public int ncol; + } + + public static class CSRSparseData { + public float[] labels; + public float[] data; + public long[] rowHeaders; + public int[] colIndex; + } + + public static DenseData loadCSVFile(String filePath) throws FileNotFoundException, UnsupportedEncodingException, IOException { + DenseData denseData = new DenseData(); + + File f = new File(filePath); + FileInputStream in = new FileInputStream(f); + BufferedReader reader = new BufferedReader(new InputStreamReader(in, "UTF-8")); + + denseData.nrow = 0; + denseData.ncol = -1; + String line; + List tlabels = new ArrayList<>(); + List tdata = new ArrayList<>(); + + while((line=reader.readLine()) != null) { + String[] items = line.trim().split(","); + if(items.length==0) { + continue; + } + denseData.nrow++; + if(denseData.ncol == -1) { + denseData.ncol = items.length - 1; + } + + tlabels.add(Float.valueOf(items[items.length-1])); + for(int i=0; i tlabels = new ArrayList<>(); + List tdata = new ArrayList<>(); + List theaders = new ArrayList<>(); + List tindex = new ArrayList<>(); + + File f = new File(filePath); + FileInputStream in = new FileInputStream(f); + BufferedReader reader = new BufferedReader(new InputStreamReader(in, "UTF-8")); + + String line; + long rowheader = 0; + theaders.add(rowheader); + while((line=reader.readLine()) != null) { + String[] items = line.trim().split(" "); + if(items.length==0) { + continue; + 
} + + rowheader += items.length - 1; + theaders.add(rowheader); + tlabels.add(Float.valueOf(items[0])); + + for(int i=1; i>{ + List> params = new ArrayList<>(); + + /** + * put param key-value pair + * @param key + * @param value + */ + public void put(String key, Object value) { + params.add(new AbstractMap.SimpleEntry<>(key, value)); + } + + @Override + public String toString(){ + String paramsInfo = ""; + for(Entry param : params) { + paramsInfo += param.getKey() + ":" + param.getValue() + "\n"; + } + return paramsInfo; + } + + @Override + public Iterator> iterator() { + return params.iterator(); + } +} diff --git a/java/xgboost4j/LICENSE b/java/xgboost4j/LICENSE new file mode 100644 index 000000000..9a1673be2 --- /dev/null +++ b/java/xgboost4j/LICENSE @@ -0,0 +1,15 @@ +/* +Copyright (c) 2014 by Contributors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ \ No newline at end of file diff --git a/java/xgboost4j/README.md b/java/xgboost4j/README.md new file mode 100644 index 000000000..e46a5b3a0 --- /dev/null +++ b/java/xgboost4j/README.md @@ -0,0 +1,23 @@ +# xgboost4j +this is a java wrapper for xgboost (https://github.com/dmlc/xgboost) +the structure of this wrapper is almost the same as the official python wrapper. +core of this wrapper is two classes: + +* DMatrix: for handling data + +* Booster: for train and predict + +## usage: + + simple examples could be found in test package: + + * Simple Train Example: org.dmlc.xgboost4j.TrainExample.java + + * Simple Predict Example: org.dmlc.xgboost4j.PredictExample.java + + * Cross Validation Example: org.dmlc.xgboost4j.example.CVExample.java + +## native library: + + only 64-bit linux/windows is supported now, if you want to build native wrapper library yourself, please refer to + https://github.com/yanqingmen/xgboost-java, and put your native library to the "./src/main/resources/lib" folder and replace the originals. (either "libxgboostjavawrapper.so" for linux or "xgboostjavawrapper.dll" for windows) diff --git a/java/xgboost4j/pom.xml b/java/xgboost4j/pom.xml new file mode 100644 index 000000000..5e312bf4f --- /dev/null +++ b/java/xgboost4j/pom.xml @@ -0,0 +1,35 @@ + + + 4.0.0 + org.dmlc + xgboost4j + 1.1 + jar + + UTF-8 + 1.7 + 1.7 + + + + + org.apache.maven.plugins + maven-javadoc-plugin + 2.10.3 + + + + + + junit + junit + 4.11 + test + + + commons-logging + commons-logging + 1.2 + + + \ No newline at end of file diff --git a/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/Booster.java b/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/Booster.java new file mode 100644 index 000000000..c5d8b1006 --- /dev/null +++ b/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/Booster.java @@ -0,0 +1,450 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ +package org.dmlc.xgboost4j; + +import java.io.BufferedWriter; +import java.io.File; +import java.io.FileNotFoundException; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.OutputStreamWriter; +import java.io.UnsupportedEncodingException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.dmlc.xgboost4j.util.Initializer; +import org.dmlc.xgboost4j.wrapper.XgboostJNI; + + +/** + * Booster for xgboost, similar to the python wrapper xgboost.py + * but custom obj function and eval function not supported at present. + * @author hzx + */ +public final class Booster { + private static final Log logger = LogFactory.getLog(Booster.class); + + long handle = 0; + + //load native library + static { + try { + Initializer.InitXgboost(); + } catch (IOException ex) { + logger.error("load native library failed."); + logger.error(ex); + } + } + + /** + * init Booster from dMatrixs + * @param params parameters + * @param dMatrixs DMatrix array + */ + public Booster(Iterable> params, DMatrix[] dMatrixs) { + init(dMatrixs); + setParam("seed","0"); + setParams(params); + } + + + + /** + * load model from modelPath + * @param params parameters + * @param modelPath booster modelPath (model generated by booster.saveModel) + */ + public Booster(Iterable> params, String modelPath) { + handle = XgboostJNI.XGBoosterCreate(new long[] {}); + loadModel(modelPath); + setParam("seed","0"); + setParams(params); + } + + + + + private void init(DMatrix[] dMatrixs) { + long[] handles = null; + if(dMatrixs != null) { + handles = dMatrixs2handles(dMatrixs); + } + handle = XgboostJNI.XGBoosterCreate(handles); + } + + /** + * set parameter + * @param key param name + * @param value param value + */ + public final void setParam(String key, String value) { + XgboostJNI.XGBoosterSetParam(handle, key, value); + } + + /** + * set parameters + * @param params parameters key-value map + */ + public void setParams(Iterable> params) { + if(params!=null) { + for(Map.Entry entry : params) { + setParam(entry.getKey(), entry.getValue().toString()); + } + } + } + + + /** + * Update (one iteration) + * @param dtrain training data + * @param iter current iteration number + */ + public void update(DMatrix dtrain, int iter) { + XgboostJNI.XGBoosterUpdateOneIter(handle, iter, dtrain.getHandle()); + } + + /** + * update with customize obj func + * @param dtrain training data + * @param iter current iteration number + * @param obj customized objective class + */ + public void update(DMatrix dtrain, int iter, IObjective obj) { + float[][] predicts = predict(dtrain, true); + List gradients = obj.getGradient(predicts, dtrain); + boost(dtrain, gradients.get(0), gradients.get(1)); + } + + /** + * update with give grad and hess + * @param dtrain training data + * @param grad first order of gradient + * @param hess seconde order of gradient + */ + public void boost(DMatrix dtrain, float[] grad, float[] hess) { + if(grad.length != hess.length) { + throw new 
AssertionError(String.format("grad/hess length mismatch %s / %s", grad.length, hess.length)); + } + XgboostJNI.XGBoosterBoostOneIter(handle, dtrain.getHandle(), grad, hess); + } + + /** + * evaluate with given dmatrixs. + * @param evalMatrixs dmatrixs for evaluation + * @param evalNames name for eval dmatrixs, used for check results + * @param iter current eval iteration + * @return eval information + */ + public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, int iter) { + long[] handles = dMatrixs2handles(evalMatrixs); + String evalInfo = XgboostJNI.XGBoosterEvalOneIter(handle, iter, handles, evalNames); + return evalInfo; + } + + /** + * evaluate with given customized Evaluation class + * @param evalMatrixs + * @param evalNames + * @param iter + * @param eval + * @return eval information + */ + public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, int iter, IEvaluation eval) { + String evalInfo = ""; + for(int i=0; i getFeatureScore() { + String[] modelInfos = getDumpInfo(false); + Map featureScore = new HashMap<>(); + for(String tree : modelInfos) { + for(String node : tree.split("\n")) { + String[] array = node.split("\\["); + if(array.length == 1) { + continue; + } + String fid = array[1].split("\\]")[0]; + fid = fid.split("<")[0]; + if(featureScore.containsKey(fid)) { + featureScore.put(fid, 1 + featureScore.get(fid)); + } + else { + featureScore.put(fid, 1); + } + } + } + return featureScore; + } + + + /** + * get importance of each feature + * @param featureMap file to save dumped model info + * @return featureMap key: feature index, value: feature importance score + */ + public Map getFeatureScore(String featureMap) { + String[] modelInfos = getDumpInfo(featureMap, false); + Map featureScore = new HashMap<>(); + for(String tree : modelInfos) { + for(String node : tree.split("\n")) { + String[] array = node.split("\\["); + if(array.length == 1) { + continue; + } + String fid = array[1].split("\\]")[0]; + fid = fid.split("<")[0]; + if(featureScore.containsKey(fid)) { + featureScore.put(fid, 1 + featureScore.get(fid)); + } + else { + featureScore.put(fid, 1); + } + } + } + return featureScore; + } + + /** + * transfer DMatrix array to handle array (used for native functions) + * @param dmatrixs + * @return handle array for input dmatrixs + */ + private static long[] dMatrixs2handles(DMatrix[] dmatrixs) { + long[] handles = new long[dmatrixs.length]; + for(int i=0; i getGradient(float[][] predicts, DMatrix dtrain); +} diff --git a/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/util/CVPack.java b/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/util/CVPack.java new file mode 100644 index 000000000..3e67dc669 --- /dev/null +++ b/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/util/CVPack.java @@ -0,0 +1,84 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ */ +package org.dmlc.xgboost4j.util; + +import java.util.Map; +import org.dmlc.xgboost4j.IEvaluation; +import org.dmlc.xgboost4j.Booster; +import org.dmlc.xgboost4j.DMatrix; +import org.dmlc.xgboost4j.IObjective; + +/** + * cross validation package for xgb + * @author hzx + */ +public class CVPack { + DMatrix dtrain; + DMatrix dtest; + DMatrix[] dmats; + String[] names; + Booster booster; + + /** + * create an cross validation package + * @param dtrain train data + * @param dtest test data + * @param params parameters + */ + public CVPack(DMatrix dtrain, DMatrix dtest, Iterable> params) { + dmats = new DMatrix[] {dtrain, dtest}; + booster = new Booster(params, dmats); + names = new String[] {"train", "test"}; + this.dtrain = dtrain; + this.dtest = dtest; + } + + /** + * update one iteration + * @param iter iteration num + */ + public void update(int iter) { + booster.update(dtrain, iter); + } + + /** + * update one iteration + * @param iter iteration num + * @param obj customized objective + */ + public void update(int iter, IObjective obj) { + booster.update(dtrain, iter, obj); + } + + /** + * evaluation + * @param iter iteration num + * @return + */ + public String eval(int iter) { + return booster.evalSet(dmats, names, iter); + } + + /** + * evaluation + * @param iter iteration num + * @param eval customized eval + * @return + */ + public String eval(int iter, IEvaluation eval) { + return booster.evalSet(dmats, names, iter, eval); + } +} diff --git a/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/util/Initializer.java b/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/util/Initializer.java new file mode 100644 index 000000000..83932ce84 --- /dev/null +++ b/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/util/Initializer.java @@ -0,0 +1,92 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ +package org.dmlc.xgboost4j.util; + +import java.io.IOException; +import java.lang.reflect.Field; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +/** + * class to load native library + * @author hzx + */ +public class Initializer { + private static final Log logger = LogFactory.getLog(Initializer.class); + + static boolean initialized = false; + public static final String nativePath = "./lib"; + public static final String nativeResourcePath = "/lib/"; + public static final String[] libNames = new String[] {"xgboostjavawrapper"}; + + public static synchronized void InitXgboost() throws IOException { + if(initialized == false) { + for(String libName: libNames) { + smartLoad(libName); + } + initialized = true; + } + } + + /** + * load native library, this method will first try to load library from java.library.path, then try to load library in jar package. 
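+     * (if System.loadLibrary cannot resolve the library on java.library.path, the copy bundled in the jar is extracted and loaded via NativeUtils.loadLibraryFromJar)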
+ * @param libName + * @throws IOException + */ + private static void smartLoad(String libName) throws IOException { + addNativeDir(nativePath); + try { + System.loadLibrary(libName); + } + catch (UnsatisfiedLinkError e) { + try { + NativeUtils.loadLibraryFromJar(nativeResourcePath + System.mapLibraryName(libName)); + } + catch (IOException e1) { + throw e1; + } + } + } + + /** + * add libPath to java.library.path, then native library in libPath would be load properly + * @param libPath + * @throws IOException + */ + public static void addNativeDir(String libPath) throws IOException { + try { + Field field = ClassLoader.class.getDeclaredField("usr_paths"); + field.setAccessible(true); + String[] paths = (String[]) field.get(null); + for (String path : paths) { + if (libPath.equals(path)) { + return; + } + } + String[] tmp = new String[paths.length+1]; + System.arraycopy(paths,0,tmp,0,paths.length); + tmp[paths.length] = libPath; + field.set(null, tmp); + } catch (IllegalAccessException e) { + logger.error(e.getMessage()); + throw new IOException("Failed to get permissions to set library path"); + } catch (NoSuchFieldException e) { + logger.error(e.getMessage()); + throw new IOException("Failed to get field handle to set library path"); + } + } +} diff --git a/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/util/NativeUtils.java b/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/util/NativeUtils.java new file mode 100644 index 000000000..c0f199005 --- /dev/null +++ b/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/util/NativeUtils.java @@ -0,0 +1,109 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ +package org.dmlc.xgboost4j.util; + +import java.io.File; +import java.io.FileNotFoundException; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; + + +/** + * Simple library class for working with JNI (Java Native Interface) + * + * @see http://adamheinrich.com/2012/how-to-load-native-jni-library-from-jar + * + * @author Adam Heirnich <adam@adamh.cz>, http://www.adamh.cz + */ +public class NativeUtils { + + /** + * Private constructor - this class will never be instanced + */ + private NativeUtils() { + } + + /** + * Loads library from current JAR archive + * + * The file from JAR is copied into system temporary directory and then loaded. The temporary file is deleted after exiting. + * Method uses String as filename because the pathname is "abstract", not system-dependent. + * + * @param path The filename inside JAR as absolute path (beginning with '/'), e.g. /package/File.ext + * @throws IOException If temporary file creation or read/write operation fails + * @throws IllegalArgumentException If source file (param path) does not exist + * @throws IllegalArgumentException If the path is not absolute or if the filename is shorter than three characters (restriction of {@see File#createTempFile(java.lang.String, java.lang.String)}). 
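+     *
+     * A typical call, matching how Initializer in this package resolves the bundled Linux library:
+     * <code>NativeUtils.loadLibraryFromJar("/lib/libxgboostjavawrapper.so");</code>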
+ */ + public static void loadLibraryFromJar(String path) throws IOException { + + if (!path.startsWith("/")) { + throw new IllegalArgumentException("The path has to be absolute (start with '/')."); + } + + // Obtain filename from path + String[] parts = path.split("/"); + String filename = (parts.length > 1) ? parts[parts.length - 1] : null; + + // Split filename to prexif and suffix (extension) + String prefix = ""; + String suffix = null; + if (filename != null) { + parts = filename.split("\\.", 2); + prefix = parts[0]; + suffix = (parts.length > 1) ? "."+parts[parts.length - 1] : null; // Thanks, davs! :-) + } + + // Check if the filename is okay + if (filename == null || prefix.length() < 3) { + throw new IllegalArgumentException("The filename has to be at least 3 characters long."); + } + + // Prepare temporary file + File temp = File.createTempFile(prefix, suffix); + temp.deleteOnExit(); + + if (!temp.exists()) { + throw new FileNotFoundException("File " + temp.getAbsolutePath() + " does not exist."); + } + + // Prepare buffer for data copying + byte[] buffer = new byte[1024]; + int readBytes; + + // Open and check input stream + InputStream is = NativeUtils.class.getResourceAsStream(path); + if (is == null) { + throw new FileNotFoundException("File " + path + " was not found inside JAR."); + } + + // Open output stream and copy data between source file in JAR and the temporary file + OutputStream os = new FileOutputStream(temp); + try { + while ((readBytes = is.read(buffer)) != -1) { + os.write(buffer, 0, readBytes); + } + } finally { + // If read/write fails, close streams safely before throwing an exception + os.close(); + is.close(); + } + + // Finally, load the library + System.load(temp.getAbsolutePath()); + } +} diff --git a/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/util/Trainer.java b/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/util/Trainer.java new file mode 100644 index 000000000..8a336b1a8 --- /dev/null +++ b/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/util/Trainer.java @@ -0,0 +1,235 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ +package org.dmlc.xgboost4j.util; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.dmlc.xgboost4j.IEvaluation; +import org.dmlc.xgboost4j.Booster; +import org.dmlc.xgboost4j.DMatrix; +import org.dmlc.xgboost4j.IObjective; + + +/** + * trainer for xgboost + * @author hzx + */ +public class Trainer { + private static final Log logger = LogFactory.getLog(Trainer.class); + + /** + * Train a booster with given parameters. + * @param params Booster params. + * @param dtrain Data to be trained. + * @param round Number of boosting iterations. + * @param watchs a group of items to be evaluated during training, this allows user to watch performance on the validation set. 
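+     *               (any Iterable of Map.Entry pairs of a String name and a DMatrix works here; the demos build a List of AbstractMap.SimpleEntry entries)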
+ * @param obj customized objective (set to null if not used) + * @param eval customized evaluation (set to null if not used) + * @return trained booster + */ + public static Booster train(Iterable> params, DMatrix dtrain, int round, + Iterable> watchs, IObjective obj, IEvaluation eval) { + + //collect eval matrixs + String[] evalNames; + DMatrix[] evalMats; + List names = new ArrayList<>(); + List mats = new ArrayList<>(); + + for(Entry evalEntry : watchs) { + names.add(evalEntry.getKey()); + mats.add(evalEntry.getValue()); + } + + evalNames = names.toArray(new String[names.size()]); + evalMats = mats.toArray(new DMatrix[mats.size()]); + + //collect all data matrixs + DMatrix[] allMats; + if(evalMats!=null && evalMats.length>0) { + allMats = new DMatrix[evalMats.length+1]; + allMats[0] = dtrain; + System.arraycopy(evalMats, 0, allMats, 1, evalMats.length); + } + else { + allMats = new DMatrix[1]; + allMats[0] = dtrain; + } + + //initialize booster + Booster booster = new Booster(params, allMats); + + //begin to train + for(int iter=0; iter0) { + String evalInfo; + if(eval != null) { + evalInfo = booster.evalSet(evalMats, evalNames, iter, eval); + } + else { + evalInfo = booster.evalSet(evalMats, evalNames, iter); + } + logger.info(evalInfo); + } + } + return booster; + } + + /** + * Cross-validation with given paramaters. + * @param params Booster params. + * @param data Data to be trained. + * @param round Number of boosting iterations. + * @param nfold Number of folds in CV. + * @param metrics Evaluation metrics to be watched in CV. + * @param obj customized objective (set to null if not used) + * @param eval customized evaluation (set to null if not used) + * @return evaluation history + */ + public static String[] crossValiation(Iterable> params, DMatrix data, int round, int nfold, String[] metrics, IObjective obj, IEvaluation eval) { + CVPack[] cvPacks = makeNFold(data, nfold, params, metrics); + String[] evalHist = new String[round]; + String[] results = new String[cvPacks.length]; + for(int i=0; i> params, String[] evalMetrics) { + List samples = genRandPermutationNums(0, (int) data.rowNum()); + int step = samples.size()/nfold; + int[] testSlice = new int[step]; + int[] trainSlice = new int[samples.size()-step]; + int testid, trainid; + CVPack[] cvPacks = new CVPack[nfold]; + for(int i=0; i(i*step) && j<(i*step+step) && testid genRandPermutationNums(int start, int end) { + List samples = new ArrayList<>(); + for(int i=start; i > cvMap = new HashMap<>(); + String aggResult = results[0].split("\t")[0]; + for(String result : results) { + String[] items = result.split("\t"); + for(int i=1; i()); + } + cvMap.get(key).add(value); + } + } + + for(String key : cvMap.keySet()) { + float value = 0f; + for(Float tvalue : cvMap.get(key)) { + value += tvalue; + } + value /= cvMap.get(key).size(); + aggResult += String.format("\tcv-%s:%f", key, value); + } + + return aggResult; + } +} diff --git a/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/wrapper/XgboostJNI.java b/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/wrapper/XgboostJNI.java new file mode 100644 index 000000000..96a429c07 --- /dev/null +++ b/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/wrapper/XgboostJNI.java @@ -0,0 +1,48 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ +package org.dmlc.xgboost4j.wrapper; + +/** + * xgboost jni wrapper functions for xgboost_wrapper.h + * @author hzx + */ +public class XgboostJNI { + public final static native long XGDMatrixCreateFromFile(String fname, int silent); + public final static native long XGDMatrixCreateFromCSR(long[] indptr, int[] indices, float[] data); + public final static native long XGDMatrixCreateFromCSC(long[] colptr, int[] indices, float[] data); + public final static native long XGDMatrixCreateFromMat(float[] data, int nrow, int ncol, float missing); + public final static native long XGDMatrixSliceDMatrix(long handle, int[] idxset); + public final static native void XGDMatrixFree(long handle); + public final static native void XGDMatrixSaveBinary(long handle, String fname, int silent); + public final static native void XGDMatrixSetFloatInfo(long handle, String field, float[] array); + public final static native void XGDMatrixSetUIntInfo(long handle, String field, int[] array); + public final static native void XGDMatrixSetGroup(long handle, int[] group); + public final static native float[] XGDMatrixGetFloatInfo(long handle, String field); + public final static native int[] XGDMatrixGetUIntInfo(long handle, String filed); + public final static native long XGDMatrixNumRow(long handle); + public final static native long XGBoosterCreate(long[] handles); + public final static native void XGBoosterFree(long handle); + public final static native void XGBoosterSetParam(long handle, String name, String value); + public final static native void XGBoosterUpdateOneIter(long handle, int iter, long dtrain); + public final static native void XGBoosterBoostOneIter(long handle, long dtrain, float[] grad, float[] hess); + public final static native String XGBoosterEvalOneIter(long handle, int iter, long[] dmats, String[] evnames); + public final static native float[] XGBoosterPredict(long handle, long dmat, int option_mask, long ntree_limit); + public final static native void XGBoosterLoadModel(long handle, String fname); + public final static native void XGBoosterSaveModel(long handle, String fname); + public final static native void XGBoosterLoadModelFromBuffer(long handle, long buf, long len); + public final static native String XGBoosterGetModelRaw(long handle); + public final static native String[] XGBoosterDumpModel(long handle, String fmap, int with_stats); +} diff --git a/java/xgboost4j_wrapper.cpp b/java/xgboost4j_wrapper.cpp new file mode 100644 index 000000000..55dc31bc8 --- /dev/null +++ b/java/xgboost4j_wrapper.cpp @@ -0,0 +1,634 @@ +/* + Copyright (c) 2014 by Contributors + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ */ + +#include +#include "../wrapper/xgboost_wrapper.h" +#include "xgboost4j_wrapper.h" + +JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromFile + (JNIEnv *jenv, jclass jcls, jstring jfname, jint jsilent) { + jlong jresult = 0 ; + char *fname = (char *) 0 ; + int silent; + void *result = 0 ; + fname = 0; + if (jfname) { + fname = (char *)jenv->GetStringUTFChars(jfname, 0); + if (!fname) return 0; + } + silent = (int)jsilent; + result = (void *)XGDMatrixCreateFromFile((char const *)fname, silent); + *(void **)&jresult = result; + if (fname) jenv->ReleaseStringUTFChars(jfname, (const char *)fname); + return jresult; +} + +/* + * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI + * Method: XGDMatrixCreateFromCSR + * Signature: ([J[J[F)J + */ +JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromCSR + (JNIEnv *jenv, jclass jcls, jlongArray jindptr, jintArray jindices, jfloatArray jdata) { + jlong jresult = 0 ; + bst_ulong nindptr ; + bst_ulong nelem; + void *result = 0 ; + + jlong* indptr = jenv->GetLongArrayElements(jindptr, 0); + jint* indices = jenv->GetIntArrayElements(jindices, 0); + jfloat* data = jenv->GetFloatArrayElements(jdata, 0); + nindptr = (bst_ulong)jenv->GetArrayLength(jindptr); + nelem = (bst_ulong)jenv->GetArrayLength(jdata); + + result = (void *)XGDMatrixCreateFromCSR((unsigned long const *)indptr, (unsigned int const *)indices, (float const *)data, nindptr, nelem); + *(void **)&jresult = result; + + //release + jenv->ReleaseLongArrayElements(jindptr, indptr, 0); + jenv->ReleaseIntArrayElements(jindices, indices, 0); + jenv->ReleaseFloatArrayElements(jdata, data, 0); + + return jresult; +} + +/* + * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI + * Method: XGDMatrixCreateFromCSC + * Signature: ([J[J[F)J + */ +JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromCSC + (JNIEnv *jenv, jclass jcls, jlongArray jindptr, jintArray jindices, jfloatArray jdata) { + jlong jresult = 0 ; + bst_ulong nindptr ; + bst_ulong nelem; + void *result = 0 ; + + jlong* indptr = jenv->GetLongArrayElements(jindptr, NULL); + jint* indices = jenv->GetIntArrayElements(jindices, 0); + jfloat* data = jenv->GetFloatArrayElements(jdata, NULL); + nindptr = (bst_ulong)jenv->GetArrayLength(jindptr); + nelem = (bst_ulong)jenv->GetArrayLength(jdata); + + result = (void *)XGDMatrixCreateFromCSC((unsigned long const *)indptr, (unsigned int const *)indices, (float const *)data, nindptr, nelem); + *(void **)&jresult = result; + + //release + jenv->ReleaseLongArrayElements(jindptr, indptr, 0); + jenv->ReleaseIntArrayElements(jindices, indices, 0); + jenv->ReleaseFloatArrayElements(jdata, data, 0); + + return jresult; +} + +/* + * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI + * Method: XGDMatrixCreateFromMat + * Signature: ([FIIF)J + */ +JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromMat + (JNIEnv *jenv, jclass jcls, jfloatArray jdata, jint jnrow, jint jncol, jfloat jmiss) { + jlong jresult = 0 ; + bst_ulong nrow ; + bst_ulong ncol ; + float miss ; + void *result = 0 ; + + + jfloat* data = jenv->GetFloatArrayElements(jdata, 0); + nrow = (bst_ulong)jnrow; + ncol = (bst_ulong)jncol; + miss = (float)jmiss; + result = (void *)XGDMatrixCreateFromMat((float const *)data, nrow, ncol, miss); + *(void **)&jresult = result; + + //release + jenv->ReleaseFloatArrayElements(jdata, data, 0); + + return jresult; +} + +/* + * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI + * Method: 
XGDMatrixSliceDMatrix + * Signature: (J[I)J + */ +JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSliceDMatrix + (JNIEnv *jenv, jclass jcls, jlong jhandle, jintArray jindexset) { + jlong jresult = 0 ; + void *handle = (void *) 0 ; + bst_ulong len; + void *result = 0 ; + + jint* indexset = jenv->GetIntArrayElements(jindexset, 0); + handle = *(void **)&jhandle; + len = (bst_ulong)jenv->GetArrayLength(jindexset); + + result = (void *)XGDMatrixSliceDMatrix(handle, (int const *)indexset, len); + *(void **)&jresult = result; + + //release + jenv->ReleaseIntArrayElements(jindexset, indexset, 0); + + return jresult; +} + +/* + * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI + * Method: XGDMatrixFree + * Signature: (J)V + */ +JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixFree + (JNIEnv *jenv, jclass jcls, jlong jhandle) { + void *handle = (void *) 0 ; + handle = *(void **)&jhandle; + XGDMatrixFree(handle); +} + +/* + * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI + * Method: XGDMatrixSaveBinary + * Signature: (JLjava/lang/String;I)V + */ +JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSaveBinary + (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfname, jint jsilent) { + void *handle = (void *) 0 ; + char *fname = (char *) 0 ; + int silent ; + handle = *(void **)&jhandle; + fname = 0; + if (jfname) { + fname = (char *)jenv->GetStringUTFChars(jfname, 0); + if (!fname) return ; + } + silent = (int)jsilent; + XGDMatrixSaveBinary(handle, (char const *)fname, silent); + if (fname) jenv->ReleaseStringUTFChars(jfname, (const char *)fname); +} + +/* + * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI + * Method: XGDMatrixSetFloatInfo + * Signature: (JLjava/lang/String;[F)V + */ +JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetFloatInfo + (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jfloatArray jarray) { + void *handle = (void *) 0 ; + char *field = (char *) 0 ; + bst_ulong len; + + + handle = *(void **)&jhandle; + field = 0; + if (jfield) { + field = (char *)jenv->GetStringUTFChars(jfield, 0); + if (!field) return ; + } + + jfloat* array = jenv->GetFloatArrayElements(jarray, NULL); + len = (bst_ulong)jenv->GetArrayLength(jarray); + XGDMatrixSetFloatInfo(handle, (char const *)field, (float const *)array, len); + + //release + if (field) jenv->ReleaseStringUTFChars(jfield, (const char *)field); + jenv->ReleaseFloatArrayElements(jarray, array, 0); +} + +/* + * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI + * Method: XGDMatrixSetUIntInfo + * Signature: (JLjava/lang/String;[I)V + */ +JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetUIntInfo + (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jintArray jarray) { + void *handle = (void *) 0 ; + char *field = (char *) 0 ; + bst_ulong len ; + handle = *(void **)&jhandle; + field = 0; + if (jfield) { + field = (char *)jenv->GetStringUTFChars(jfield, 0); + if (!field) return ; + } + + jint* array = jenv->GetIntArrayElements(jarray, NULL); + len = (bst_ulong)jenv->GetArrayLength(jarray); + + XGDMatrixSetUIntInfo(handle, (char const *)field, (unsigned int const *)array, len); + //release + if (field) jenv->ReleaseStringUTFChars(jfield, (const char *)field); + jenv->ReleaseIntArrayElements(jarray, array, 0); +} + +/* + * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI + * Method: XGDMatrixSetGroup + * Signature: (J[I)V + */ +JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetGroup + 
(JNIEnv * jenv, jclass jcls, jlong jhandle, jintArray jarray) { + void *handle = (void *) 0 ; + bst_ulong len ; + + handle = *(void **)&jhandle; + jint* array = jenv->GetIntArrayElements(jarray, NULL); + len = (bst_ulong)jenv->GetArrayLength(jarray); + + XGDMatrixSetGroup(handle, (unsigned int const *)array, len); + + //release + jenv->ReleaseIntArrayElements(jarray, array, 0); + +} + +/* + * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI + * Method: XGDMatrixGetFloatInfo + * Signature: (JLjava/lang/String;)[F + */ +JNIEXPORT jfloatArray JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixGetFloatInfo + (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield) { + void *handle = (void *) 0 ; + char *field = (char *) 0 ; + bst_ulong len[1]; + *len = 0; + float *result = 0 ; + + handle = *(void **)&jhandle; + field = 0; + if (jfield) { + field = (char *)jenv->GetStringUTFChars(jfield, 0); + if (!field) return 0; + } + + result = (float *)XGDMatrixGetFloatInfo((void const *)handle, (char const *)field, len); + + if (field) jenv->ReleaseStringUTFChars(jfield, (const char *)field); + + jsize jlen = (jsize)*len; + jfloatArray jresult = jenv->NewFloatArray(jlen); + jenv->SetFloatArrayRegion(jresult, 0, jlen, (jfloat *)result); + return jresult; +} + +/* + * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI + * Method: XGDMatrixGetUIntInfo + * Signature: (JLjava/lang/String;)[I + */ +JNIEXPORT jintArray JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixGetUIntInfo + (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield) { + void *handle = (void *) 0 ; + char *field = (char *) 0 ; + bst_ulong len[1]; + *len = 0; + unsigned int *result = 0 ; + + handle = *(void **)&jhandle; + field = 0; + if (jfield) { + field = (char *)jenv->GetStringUTFChars(jfield, 0); + if (!field) return 0; + } + + result = (unsigned int *)XGDMatrixGetUIntInfo((void const *)handle, (char const *)field, len); + + if (field) jenv->ReleaseStringUTFChars(jfield, (const char *)field); + + jsize jlen = (jsize)*len; + jintArray jresult = jenv->NewIntArray(jlen); + jenv->SetIntArrayRegion(jresult, 0, jlen, (jint *)result); + return jresult; +} + +/* + * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI + * Method: XGDMatrixNumRow + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixNumRow + (JNIEnv *jenv, jclass jcls, jlong jhandle) { + jlong jresult = 0 ; + void *handle = (void *) 0 ; + bst_ulong result; + handle = *(void **)&jhandle; + result = (bst_ulong)XGDMatrixNumRow((void const *)handle); + jresult = (jlong)result; + return jresult; +} + +/* + * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI + * Method: XGBoosterCreate + * Signature: ([J)J + */ +JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterCreate + (JNIEnv *jenv, jclass jcls, jlongArray jhandles) { + jlong jresult = 0 ; + void **handles = 0; + bst_ulong len = 0; + void *result = 0 ; + jlong* cjhandles = 0; + + if(jhandles) { + len = (bst_ulong)jenv->GetArrayLength(jhandles); + handles = new void*[len]; + //put handle from jhandles to chandles + cjhandles = jenv->GetLongArrayElements(jhandles, 0); + for(bst_ulong i=0; iReleaseLongArrayElements(jhandles, cjhandles, 0); + } + + *(void **)&jresult = result; + return jresult; +} + +/* + * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI + * Method: XGBoosterFree + * Signature: (J)V + */ +JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterFree + (JNIEnv *jenv, jclass jcls, jlong jhandle) { + void *handle = (void *) 0 ; + handle 
= *(void **)&jhandle; + XGBoosterFree(handle); +} + + +/* + * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI + * Method: XGBoosterSetParam + * Signature: (JLjava/lang/String;Ljava/lang/String;)V + */ +JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterSetParam + (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jname, jstring jvalue) { + void *handle = (void *) 0 ; + char *name = (char *) 0 ; + char *value = (char *) 0 ; + handle = *(void **)&jhandle; + + name = 0; + if (jname) { + name = (char *)jenv->GetStringUTFChars(jname, 0); + if (!name) return ; + } + + value = 0; + if (jvalue) { + value = (char *)jenv->GetStringUTFChars(jvalue, 0); + if (!value) return ; + } + XGBoosterSetParam(handle, (char const *)name, (char const *)value); + if (name) jenv->ReleaseStringUTFChars(jname, (const char *)name); + if (value) jenv->ReleaseStringUTFChars(jvalue, (const char *)value); +} + +/* + * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI + * Method: XGBoosterUpdateOneIter + * Signature: (JIJ)V + */ +JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterUpdateOneIter + (JNIEnv *jenv, jclass jcls, jlong jhandle, jint jiter, jlong jdtrain) { + void *handle = (void *) 0 ; + int iter ; + void *dtrain = (void *) 0 ; + handle = *(void **)&jhandle; + iter = (int)jiter; + dtrain = *(void **)&jdtrain; + XGBoosterUpdateOneIter(handle, iter, dtrain); +} + +/* + * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI + * Method: XGBoosterBoostOneIter + * Signature: (JJ[F[F)V + */ +JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterBoostOneIter + (JNIEnv *jenv, jclass jcls, jlong jhandle, jlong jdtrain, jfloatArray jgrad, jfloatArray jhess) { + void *handle = (void *) 0 ; + void *dtrain = (void *) 0 ; + bst_ulong len ; + + handle = *(void **)&jhandle; + dtrain = *(void **)&jdtrain; + jfloat* grad = jenv->GetFloatArrayElements(jgrad, 0); + jfloat* hess = jenv->GetFloatArrayElements(jhess, 0); + len = (bst_ulong)jenv->GetArrayLength(jgrad); + XGBoosterBoostOneIter(handle, dtrain, grad, hess, len); + + //release + jenv->ReleaseFloatArrayElements(jgrad, grad, 0); + jenv->ReleaseFloatArrayElements(jhess, hess, 0); +} + +/* + * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI + * Method: XGBoosterEvalOneIter + * Signature: (JI[J[Ljava/lang/String;)Ljava/lang/String; + */ +JNIEXPORT jstring JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterEvalOneIter + (JNIEnv *jenv, jclass jcls, jlong jhandle, jint jiter, jlongArray jdmats, jobjectArray jevnames) { + jstring jresult = 0 ; + void *handle = (void *) 0 ; + int iter ; + void **dmats = 0; + char **evnames = 0; + bst_ulong len ; + char *result = 0 ; + + handle = *(void **)&jhandle; + iter = (int)jiter; + len = (bst_ulong)jenv->GetArrayLength(jdmats); + + + if(len > 0) { + dmats = new void*[len]; + evnames = new char*[len]; + } + + //put handle from jhandles to chandles + jlong* cjdmats = jenv->GetLongArrayElements(jdmats, 0); + for(bst_ulong i=0; iGetObjectArrayElement(jevnames, i); + evnames[i] = (char *)jenv->GetStringUTFChars(jevname, 0); + } + + result = (char *)XGBoosterEvalOneIter(handle, iter, dmats, (char const *(*))evnames, len); + + if(len > 0) { + delete[] dmats; + //release string chars + for(bst_ulong i=0; iGetObjectArrayElement(jevnames, i); + jenv->ReleaseStringUTFChars(jevname, (const char*)evnames[i]); + } + delete[] evnames; + jenv->ReleaseLongArrayElements(jdmats, cjdmats, 0); + } + + if (result) jresult = jenv->NewStringUTF((const char *)result); + + return jresult; +} + +/* + * Class: 
org_dmlc_xgboost4j_wrapper_XgboostJNI + * Method: XGBoosterPredict + * Signature: (JJIJ)[F + */ +JNIEXPORT jfloatArray JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterPredict + (JNIEnv *jenv, jclass jcls, jlong jhandle, jlong jdmat, jint joption_mask, jlong jntree_limit) { + void *handle = (void *) 0 ; + void *dmat = (void *) 0 ; + int option_mask ; + unsigned int ntree_limit ; + bst_ulong len[1]; + *len = 0; + float *result = 0 ; + + handle = *(void **)&jhandle; + dmat = *(void **)&jdmat; + option_mask = (int)joption_mask; + ntree_limit = (unsigned int)jntree_limit; + + result = (float *)XGBoosterPredict(handle, dmat, option_mask, ntree_limit, len); + + jsize jlen = (jsize)*len; + jfloatArray jresult = jenv->NewFloatArray(jlen); + jenv->SetFloatArrayRegion(jresult, 0, jlen, (jfloat *)result); + return jresult; +} + +/* + * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI + * Method: XGBoosterLoadModel + * Signature: (JLjava/lang/String;)V + */ +JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterLoadModel + (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfname) { + void *handle = (void *) 0 ; + char *fname = (char *) 0 ; + handle = *(void **)&jhandle; + fname = 0; + if (jfname) { + fname = (char *)jenv->GetStringUTFChars(jfname, 0); + if (!fname) return ; + } + XGBoosterLoadModel(handle,(char const *)fname); + if (fname) jenv->ReleaseStringUTFChars(jfname, (const char *)fname); +} + +/* + * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI + * Method: XGBoosterSaveModel + * Signature: (JLjava/lang/String;)V + */ +JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterSaveModel + (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfname) { + void *handle = (void *) 0 ; + char *fname = (char *) 0 ; + handle = *(void **)&jhandle; + fname = 0; + if (jfname) { + fname = (char *)jenv->GetStringUTFChars(jfname, 0); + if (!fname) return ; + } + XGBoosterSaveModel(handle, (char const *)fname); + if (fname) jenv->ReleaseStringUTFChars(jfname, (const char *)fname); +} + +/* + * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI + * Method: XGBoosterLoadModelFromBuffer + * Signature: (JJJ)V + */ +JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterLoadModelFromBuffer + (JNIEnv *jenv, jclass jcls, jlong jhandle, jlong jbuf, jlong jlen) { + void *handle = (void *) 0 ; + void *buf = (void *) 0 ; + bst_ulong len ; + handle = *(void **)&jhandle; + buf = *(void **)&jbuf; + len = (bst_ulong)jlen; + XGBoosterLoadModelFromBuffer(handle, (void const *)buf, len); +} + +/* + * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI + * Method: XGBoosterGetModelRaw + * Signature: (J)Ljava/lang/String; + */ +JNIEXPORT jstring JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterGetModelRaw + (JNIEnv * jenv, jclass jcls, jlong jhandle) { + jstring jresult = 0 ; + void *handle = (void *) 0 ; + bst_ulong len[1]; + *len = 0; + char *result = 0 ; + handle = *(void **)&jhandle; + + result = (char *)XGBoosterGetModelRaw(handle, len); + if (result) jresult = jenv->NewStringUTF((const char *)result); + return jresult; +} + +/* + * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI + * Method: XGBoosterDumpModel + * Signature: (JLjava/lang/String;I)[Ljava/lang/String; + */ +JNIEXPORT jobjectArray JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterDumpModel + (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfmap, jint jwith_stats) { + void *handle = (void *) 0 ; + char *fmap = (char *) 0 ; + int with_stats ; + bst_ulong len[1]; + *len = 0; + + char **result 
= 0 ;
+ handle = *(void **)&jhandle;
+ fmap = 0;
+ if (jfmap) {
+ fmap = (char *)jenv->GetStringUTFChars(jfmap, 0);
+ if (!fmap) return 0;
+ }
+ with_stats = (int)jwith_stats;
+
+ result = (char **)XGBoosterDumpModel(handle, (char const *)fmap, with_stats, len);
+
+ jsize jlen = (jsize)*len;
+ jobjectArray jresult = jenv->NewObjectArray(jlen, jenv->FindClass("java/lang/String"), jenv->NewStringUTF(""));
+ for(int i=0 ; i<jlen; i++) {
+ jenv->SetObjectArrayElement(jresult, i, jenv->NewStringUTF((const char*)result[i]));
+ }
+
+ if (fmap) jenv->ReleaseStringUTFChars(jfmap, (const char *)fmap);
+ return jresult;
+}
\ No newline at end of file
diff --git a/java/xgboost4j_wrapper.h b/java/xgboost4j_wrapper.h
new file mode 100644
index 000000000..d13b86f8c
--- /dev/null
+++ b/java/xgboost4j_wrapper.h
@@ -0,0 +1,213 @@
+/* DO NOT EDIT THIS FILE - it is machine generated */
+#include <jni.h>
+/* Header for class org_dmlc_xgboost4j_wrapper_XgboostJNI */
+
+#ifndef _Included_org_dmlc_xgboost4j_wrapper_XgboostJNI
+#define _Included_org_dmlc_xgboost4j_wrapper_XgboostJNI
+#ifdef __cplusplus
+extern "C" {
+#endif
+/*
+ * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
+ * Method: XGDMatrixCreateFromFile
+ * Signature: (Ljava/lang/String;I)J
+ */
+JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromFile
+ (JNIEnv *, jclass, jstring, jint);
+
+/*
+ * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
+ * Method: XGDMatrixCreateFromCSR
+ * Signature: ([J[J[F)J
+ */
+JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromCSR
+ (JNIEnv *, jclass, jlongArray, jintArray, jfloatArray);
+
+/*
+ * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
+ * Method: XGDMatrixCreateFromCSC
+ * Signature: ([J[J[F)J
+ */
+JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromCSC
+ (JNIEnv *, jclass, jlongArray, jintArray, jfloatArray);
+
+/*
+ * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
+ * Method: XGDMatrixCreateFromMat
+ * Signature: ([FIIF)J
+ */
+JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromMat
+ (JNIEnv *, jclass, jfloatArray, jint, jint, jfloat);
+
+/*
+ * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
+ * Method: XGDMatrixSliceDMatrix
+ * Signature: (J[I)J
+ */
+JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSliceDMatrix
+ (JNIEnv *, jclass, jlong, jintArray);
+
+/*
+ * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
+ * Method: XGDMatrixFree
+ * Signature: (J)V
+ */
+JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixFree
+ (JNIEnv *, jclass, jlong);
+
+/*
+ * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
+ * Method: XGDMatrixSaveBinary
+ * Signature: (JLjava/lang/String;I)V
+ */
+JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSaveBinary
+ (JNIEnv *, jclass, jlong, jstring, jint);
+
+/*
+ * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
+ * Method: XGDMatrixSetFloatInfo
+ * Signature: (JLjava/lang/String;[F)V
+ */
+JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetFloatInfo
+ (JNIEnv *, jclass, jlong, jstring, jfloatArray);
+
+/*
+ * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
+ * Method: XGDMatrixSetUIntInfo
+ * Signature: (JLjava/lang/String;[I)V
+ */
+JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetUIntInfo
+ (JNIEnv *, jclass, jlong, jstring, jintArray);
+
+/*
+ * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
+ * Method: XGDMatrixSetGroup
+ * Signature: (J[I)V
+ */
+JNIEXPORT void JNICALL 
Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetGroup + (JNIEnv *, jclass, jlong, jintArray); + +/* + * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI + * Method: XGDMatrixGetFloatInfo + * Signature: (JLjava/lang/String;)[F + */ +JNIEXPORT jfloatArray JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixGetFloatInfo + (JNIEnv *, jclass, jlong, jstring); + +/* + * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI + * Method: XGDMatrixGetUIntInfo + * Signature: (JLjava/lang/String;)[I + */ +JNIEXPORT jintArray JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixGetUIntInfo + (JNIEnv *, jclass, jlong, jstring); + +/* + * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI + * Method: XGDMatrixNumRow + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixNumRow + (JNIEnv *, jclass, jlong); + +/* + * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI + * Method: XGBoosterCreate + * Signature: ([J)J + */ +JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterCreate + (JNIEnv *, jclass, jlongArray); + +/* + * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI + * Method: XGBoosterFree + * Signature: (J)V + */ +JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterFree + (JNIEnv *, jclass, jlong); + +/* + * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI + * Method: XGBoosterSetParam + * Signature: (JLjava/lang/String;Ljava/lang/String;)V + */ +JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterSetParam + (JNIEnv *, jclass, jlong, jstring, jstring); + +/* + * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI + * Method: XGBoosterUpdateOneIter + * Signature: (JIJ)V + */ +JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterUpdateOneIter + (JNIEnv *, jclass, jlong, jint, jlong); + +/* + * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI + * Method: XGBoosterBoostOneIter + * Signature: (JJ[F[F)V + */ +JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterBoostOneIter + (JNIEnv *, jclass, jlong, jlong, jfloatArray, jfloatArray); + +/* + * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI + * Method: XGBoosterEvalOneIter + * Signature: (JI[J[Ljava/lang/String;)Ljava/lang/String; + */ +JNIEXPORT jstring JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterEvalOneIter + (JNIEnv *, jclass, jlong, jint, jlongArray, jobjectArray); + +/* + * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI + * Method: XGBoosterPredict + * Signature: (JJIJ)[F + */ +JNIEXPORT jfloatArray JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterPredict + (JNIEnv *, jclass, jlong, jlong, jint, jlong); + +/* + * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI + * Method: XGBoosterLoadModel + * Signature: (JLjava/lang/String;)V + */ +JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterLoadModel + (JNIEnv *, jclass, jlong, jstring); + +/* + * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI + * Method: XGBoosterSaveModel + * Signature: (JLjava/lang/String;)V + */ +JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterSaveModel + (JNIEnv *, jclass, jlong, jstring); + +/* + * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI + * Method: XGBoosterLoadModelFromBuffer + * Signature: (JJJ)V + */ +JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterLoadModelFromBuffer + (JNIEnv *, jclass, jlong, jlong, jlong); + +/* + * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI + * Method: XGBoosterGetModelRaw + * Signature: (J)Ljava/lang/String; + */ 
+JNIEXPORT jstring JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterGetModelRaw + (JNIEnv *, jclass, jlong); + +/* + * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI + * Method: XGBoosterDumpModel + * Signature: (JLjava/lang/String;I)[Ljava/lang/String; + */ +JNIEXPORT jobjectArray JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterDumpModel + (JNIEnv *, jclass, jlong, jstring, jint); + +#ifdef __cplusplus +} +#endif +#endif diff --git a/windows/xgboost.sln b/windows/xgboost.sln index f2b08a456..7bd8db5b2 100644 --- a/windows/xgboost.sln +++ b/windows/xgboost.sln @@ -10,6 +10,8 @@ Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "xgboost_wrapper", "xgboost_ EndProject Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "rabit", "..\subtree\rabit\windows\rabit\rabit.vcxproj", "{D7B77D06-4F5F-4BD7-B81E-7CC8EBBE684F}" EndProject +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "xgboostjavawrapper", "xgboostjavawrapper\xgboostjavawrapper.vcxproj", "{20A0E4D7-20C7-4EC1-BDF6-0D469CE239AA}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Win32 = Debug|Win32 @@ -41,6 +43,14 @@ Global {D7B77D06-4F5F-4BD7-B81E-7CC8EBBE684F}.Release|Win32.Build.0 = Release|Win32 {D7B77D06-4F5F-4BD7-B81E-7CC8EBBE684F}.Release|x64.ActiveCfg = Release|x64 {D7B77D06-4F5F-4BD7-B81E-7CC8EBBE684F}.Release|x64.Build.0 = Release|x64 + {20A0E4D7-20C7-4EC1-BDF6-0D469CE239AA}.Debug|Win32.ActiveCfg = Debug|Win32 + {20A0E4D7-20C7-4EC1-BDF6-0D469CE239AA}.Debug|Win32.Build.0 = Debug|Win32 + {20A0E4D7-20C7-4EC1-BDF6-0D469CE239AA}.Debug|x64.ActiveCfg = Debug|x64 + {20A0E4D7-20C7-4EC1-BDF6-0D469CE239AA}.Debug|x64.Build.0 = Debug|x64 + {20A0E4D7-20C7-4EC1-BDF6-0D469CE239AA}.Release|Win32.ActiveCfg = Release|Win32 + {20A0E4D7-20C7-4EC1-BDF6-0D469CE239AA}.Release|Win32.Build.0 = Release|Win32 + {20A0E4D7-20C7-4EC1-BDF6-0D469CE239AA}.Release|x64.ActiveCfg = Release|x64 + {20A0E4D7-20C7-4EC1-BDF6-0D469CE239AA}.Release|x64.Build.0 = Release|x64 EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE diff --git a/windows/xgboostjavawrapper/xgboostjavawrapper.vcxproj b/windows/xgboostjavawrapper/xgboostjavawrapper.vcxproj new file mode 100644 index 000000000..e55dfff71 --- /dev/null +++ b/windows/xgboostjavawrapper/xgboostjavawrapper.vcxproj @@ -0,0 +1,129 @@ + + + + + Debug + Win32 + + + Debug + x64 + + + Release + Win32 + + + Release + x64 + + + + + + + + + + + + + {20A0E4D7-20C7-4EC1-BDF6-0D469CE239AA} + xgboost_wrapper + + + + DynamicLibrary + true + MultiByte + + + DynamicLibrary + true + MultiByte + + + DynamicLibrary + false + true + MultiByte + + + DynamicLibrary + false + true + MultiByte + + + + + + + + + + + + + + + + + + + $(SolutionDir)$(Platform)\$(Configuration)\ + + + + Level3 + Disabled + + + true + + + + + Level3 + Disabled + + + true + + + + + Level3 + MaxSpeed + true + true + true + $(JAVA_HOME)\include;$(JAVA_HOME)\include\win32;%(AdditionalIncludeDirectories) + + + true + true + true + + + + + Level3 + MaxSpeed + true + true + true + MultiThreaded + $(JAVA_HOME)\include\win32;$(JAVA_HOME)\include;%(AdditionalIncludeDirectories) + + + true + true + true + ws2_32.lib;%(AdditionalDependencies) + + + + + + \ No newline at end of file
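
For reference, the sketch below drives the raw `XgboostJNI` bindings declared above end to end; in normal use the `DMatrix`, `Booster` and `Trainer` classes wrap these calls. File names and parameter values are illustrative, and the snippet assumes the native library (`xgboostjavawrapper`) has already been loaded into the JVM (for example via `System.loadLibrary` or `NativeUtils.loadLibraryFromJar`).

```
import org.dmlc.xgboost4j.wrapper.XgboostJNI;

public class XgboostJNISketch {
    public static void main(String[] args) {
        // assumes libxgboostjavawrapper has already been loaded into the JVM

        // build DMatrix handles from libsvm text files (paths are illustrative)
        long dtrain = XgboostJNI.XGDMatrixCreateFromFile("train.txt", 1);
        long dtest = XgboostJNI.XGDMatrixCreateFromFile("test.txt", 1);

        // create a booster over the data matrices and set a few parameters
        long booster = XgboostJNI.XGBoosterCreate(new long[]{dtrain, dtest});
        XgboostJNI.XGBoosterSetParam(booster, "objective", "binary:logistic");
        XgboostJNI.XGBoosterSetParam(booster, "max_depth", "3");
        XgboostJNI.XGBoosterSetParam(booster, "eta", "0.1");

        // run boosting rounds, printing the evaluation string each iteration
        for (int iter = 0; iter < 10; iter++) {
            XgboostJNI.XGBoosterUpdateOneIter(booster, iter, dtrain);
            String evalInfo = XgboostJNI.XGBoosterEvalOneIter(booster, iter,
                    new long[]{dtrain, dtest}, new String[]{"train", "test"});
            System.out.println(evalInfo);
        }

        // predict on the test matrix, then release the native handles
        float[] preds = XgboostJNI.XGBoosterPredict(booster, dtest, 0, 0);
        System.out.println("number of predictions: " + preds.length);
        XgboostJNI.XGBoosterFree(booster);
        XgboostJNI.XGDMatrixFree(dtrain);
        XgboostJNI.XGDMatrixFree(dtest);
    }
}
```

This is roughly the flow that `Trainer.train` automates: it builds a booster from the parameter and watch lists, updates it once per round, and logs the per-round evaluation output.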