diff --git a/.gitignore b/.gitignore index 27ff1a764..4e8c0610b 100644 --- a/.gitignore +++ b/.gitignore @@ -70,3 +70,12 @@ config.mk xgboost *.data build_plugin +dmlc-core +.idea +recommonmark/ +tags +*.iml +*.class +target + +*.swp diff --git a/Makefile b/Makefile index c3becf685..e8bc5f9b8 100644 --- a/Makefile +++ b/Makefile @@ -84,7 +84,7 @@ $(DMLC_CORE)/libdmlc.a: $(RABIT)/lib/$(LIB_RABIT): + cd $(RABIT); make lib/$(LIB_RABIT); cd $(ROOTDIR) -java: java/libxgboost4j.so +jvm: jvm-packages/lib/libxgboost4j.so SRC = $(wildcard src/*.cc src/*/*.cc) ALL_OBJ = $(patsubst src/%.cc, build/%.o, $(SRC)) $(PLUGIN_OBJS) @@ -120,7 +120,8 @@ lib/libxgboost.dll lib/libxgboost.so: $(ALL_DEP) @mkdir -p $(@D) $(CXX) $(CFLAGS) -shared -o $@ $(filter %.o %a, $^) $(LDFLAGS) -java/libxgboost4j.so: java/xgboost4j_wrapper.cpp $(ALL_DEP) +jvm-packages/lib/libxgboost4j.so: jvm-packages/xgboost4j/src/native/xgboost4j.cpp $(ALL_DEP) + @mkdir -p $(@D) $(CXX) $(CFLAGS) $(JAVAINCFLAGS) -shared -o $@ $(filter %.cpp %.o %.a, $^) $(LDFLAGS) xgboost: $(CLI_OBJ) $(ALL_DEP) diff --git a/java/xgboost4j-demo/pom.xml b/java/xgboost4j-demo/pom.xml deleted file mode 100644 index 28c51bc13..000000000 --- a/java/xgboost4j-demo/pom.xml +++ /dev/null @@ -1,36 +0,0 @@ - - - 4.0.0 - org.dmlc - xgboost4j-demo - 1.0 - jar - - UTF-8 - 1.7 - 1.7 - - - - org.dmlc - xgboost4j - 1.1 - - - commons-io - commons-io - 2.4 - - - org.apache.commons - commons-lang3 - 3.4 - - - junit - junit - 4.11 - test - - - \ No newline at end of file diff --git a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/BasicWalkThrough.java b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/BasicWalkThrough.java deleted file mode 100644 index 0c6529d2c..000000000 --- a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/BasicWalkThrough.java +++ /dev/null @@ -1,164 +0,0 @@ -/* - Copyright (c) 2014 by Contributors - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - */ -package org.dmlc.xgboost4j.demo; - -import java.io.File; -import java.io.IOException; -import java.io.UnsupportedEncodingException; -import java.util.AbstractMap; -import java.util.AbstractMap.SimpleEntry; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Map.Entry; -import org.dmlc.xgboost4j.Booster; -import org.dmlc.xgboost4j.DMatrix; -import org.dmlc.xgboost4j.demo.util.DataLoader; -import org.dmlc.xgboost4j.demo.util.Params; -import org.dmlc.xgboost4j.util.Trainer; -import org.dmlc.xgboost4j.util.XGBoostError; - -/** - * a simple example of java wrapper for xgboost - * @author hzx - */ -public class BasicWalkThrough { - public static boolean checkPredicts(float[][] fPredicts, float[][] sPredicts) { - if(fPredicts.length != sPredicts.length) { - return false; - } - - for(int i=0; i> object would be used as paramters - //e.g. 
- // Map paramMap = new HashMap() { - // { - // put("eta", 1.0); - // put("max_depth", 2); - // put("silent", 1); - // put("objective", "binary:logistic"); - // } - // }; - // Iterable> param = paramMap.entrySet(); - - //or - // List> param = new ArrayList>() { - // { - // add(new SimpleEntry("eta", 1.0)); - // add(new SimpleEntry("max_depth", 2.0)); - // add(new SimpleEntry("silent", 1)); - // add(new SimpleEntry("objective", "binary:logistic")); - // } - // }; - - //we use a util class Params to handle parameters as example - Iterable> param = new Params() { - { - put("eta", 1.0); - put("max_depth", 2); - put("silent", 1); - put("objective", "binary:logistic"); - } - }; - - - - //specify watchList to set evaluation dmats - //note: any Iterable> object would be used as watchList - //e.g. - //an entrySet of Map is good - // Map watchMap = new HashMap<>(); - // watchMap.put("train", trainMat); - // watchMap.put("test", testMat); - // Iterable> watchs = watchMap.entrySet(); - - //we use a List of Entry WatchList as example - List> watchs = new ArrayList<>(); - watchs.add(new SimpleEntry<>("train", trainMat)); - watchs.add(new SimpleEntry<>("test", testMat)); - - //set round - int round = 2; - - //train a boost model - Booster booster = Trainer.train(param, trainMat, round, watchs, null, null); - - //predict - float[][] predicts = booster.predict(testMat); - - //save model to modelPath - File file = new File("./model"); - if(!file.exists()) { - file.mkdirs(); - } - - String modelPath = "./model/xgb.model"; - booster.saveModel(modelPath); - - //dump model - booster.dumpModel("./model/dump.raw.txt", false); - - //dump model with feature map - booster.dumpModel("./model/dump.nice.txt", "../../demo/data/featmap.txt", false); - - //save dmatrix into binary buffer - testMat.saveBinary("./model/dtest.buffer"); - - //reload model and data - Booster booster2 = new Booster(param, "./model/xgb.model"); - DMatrix testMat2 = new DMatrix("./model/dtest.buffer"); - float[][] predicts2 = booster2.predict(testMat2); - - - //check the two predicts - System.out.println(checkPredicts(predicts, predicts2)); - - System.out.println("start build dmatrix from csr sparse data ..."); - //build dmatrix from CSR Sparse Matrix - DataLoader.CSRSparseData spData = DataLoader.loadSVMFile("../../demo/data/agaricus.txt.train"); - - DMatrix trainMat2 = new DMatrix(spData.rowHeaders, spData.colIndex, spData.data, DMatrix.SparseType.CSR); - trainMat2.setLabel(spData.labels); - - //specify watchList - List> watchs2 = new ArrayList<>(); - watchs2.add(new SimpleEntry<>("train", trainMat2)); - watchs2.add(new SimpleEntry<>("test", testMat2)); - Booster booster3 = Trainer.train(param, trainMat2, round, watchs2, null, null); - float[][] predicts3 = booster3.predict(testMat2); - - //check predicts - System.out.println(checkPredicts(predicts, predicts3)); - } -} diff --git a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/BoostFromPrediction.java b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/BoostFromPrediction.java deleted file mode 100644 index a81da0c59..000000000 --- a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/BoostFromPrediction.java +++ /dev/null @@ -1,67 +0,0 @@ -/* - Copyright (c) 2014 by Contributors - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. 
- You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - */ -package org.dmlc.xgboost4j.demo; - -import java.util.AbstractMap; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import org.dmlc.xgboost4j.Booster; -import org.dmlc.xgboost4j.DMatrix; -import org.dmlc.xgboost4j.demo.util.Params; -import org.dmlc.xgboost4j.util.Trainer; -import org.dmlc.xgboost4j.util.XGBoostError; - -/** - * example for start from a initial base prediction - * @author hzx - */ -public class BoostFromPrediction { - public static void main(String[] args) throws XGBoostError { - System.out.println("start running example to start from a initial prediction"); - - // load file from text file, also binary buffer generated by xgboost4j - DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); - DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test"); - - //specify parameters - Params param = new Params() { - { - put("eta", 1.0); - put("max_depth", 2); - put("silent", 1); - put("objective", "binary:logistic"); - } - }; - - //specify watchList - List> watchs = new ArrayList<>(); - watchs.add(new AbstractMap.SimpleEntry<>("train", trainMat)); - watchs.add(new AbstractMap.SimpleEntry<>("test", testMat)); - - //train xgboost for 1 round - Booster booster = Trainer.train(param, trainMat, 1, watchs, null, null); - - float[][] trainPred = booster.predict(trainMat, true); - float[][] testPred = booster.predict(testMat, true); - - trainMat.setBaseMargin(trainPred); - testMat.setBaseMargin(testPred); - - System.out.println("result of running from initial prediction"); - Booster booster2 = Trainer.train(param, trainMat, 1, watchs, null, null); - } -} diff --git a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/CrossValidation.java b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/CrossValidation.java deleted file mode 100644 index 6dcf917da..000000000 --- a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/CrossValidation.java +++ /dev/null @@ -1,54 +0,0 @@ -/* - Copyright (c) 2014 by Contributors - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
- */ -package org.dmlc.xgboost4j.demo; - -import java.io.IOException; -import org.dmlc.xgboost4j.DMatrix; -import org.dmlc.xgboost4j.util.Trainer; -import org.dmlc.xgboost4j.demo.util.Params; -import org.dmlc.xgboost4j.util.XGBoostError; - -/** - * an example of cross validation - * @author hzx - */ -public class CrossValidation { - public static void main(String[] args) throws IOException, XGBoostError { - //load train mat - DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); - - //set params - Params param = new Params() { - { - put("eta", 1.0); - put("max_depth", 3); - put("silent", 1); - put("nthread", 6); - put("objective", "binary:logistic"); - put("gamma", 1.0); - put("eval_metric", "error"); - } - }; - - //do 5-fold cross validation - int round = 2; - int nfold = 5; - //set additional eval_metrics - String[] metrics = null; - - String[] evalHist = Trainer.crossValiation(param, trainMat, round, nfold, metrics, null, null); - } -} diff --git a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/CustomObjective.java b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/CustomObjective.java deleted file mode 100644 index 2b8c44ecd..000000000 --- a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/CustomObjective.java +++ /dev/null @@ -1,175 +0,0 @@ -/* - Copyright (c) 2014 by Contributors - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
- */ -package org.dmlc.xgboost4j.demo; - -import java.util.AbstractMap; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; -import org.dmlc.xgboost4j.Booster; -import org.dmlc.xgboost4j.IEvaluation; -import org.dmlc.xgboost4j.DMatrix; -import org.dmlc.xgboost4j.IObjective; -import org.dmlc.xgboost4j.demo.util.Params; -import org.dmlc.xgboost4j.util.Trainer; -import org.dmlc.xgboost4j.util.XGBoostError; - -/** - * an example user define objective and eval - * NOTE: when you do customized loss function, the default prediction value is margin - * this may make buildin evalution metric not function properly - * for example, we are doing logistic loss, the prediction is score before logistic transformation - * he buildin evaluation error assumes input is after logistic transformation - * Take this in mind when you use the customization, and maybe you need write customized evaluation function - * @author hzx - */ -public class CustomObjective { - /** - * loglikelihoode loss obj function - */ - public static class LogRegObj implements IObjective { - private static final Log logger = LogFactory.getLog(LogRegObj.class); - - /** - * simple sigmoid func - * @param input - * @return - * Note: this func is not concern about numerical stability, only used as example - */ - public float sigmoid(float input) { - float val = (float) (1/(1+Math.exp(-input))); - return val; - } - - public float[][] transform(float[][] predicts) { - int nrow = predicts.length; - float[][] transPredicts = new float[nrow][1]; - - for(int i=0; i getGradient(float[][] predicts, DMatrix dtrain) { - int nrow = predicts.length; - List gradients = new ArrayList<>(); - float[] labels; - try { - labels = dtrain.getLabel(); - } catch (XGBoostError ex) { - logger.error(ex); - return null; - } - float[] grad = new float[nrow]; - float[] hess = new float[nrow]; - - float[][] transPredicts = transform(predicts); - - for(int i=0; i0) { - error++; - } - else if(labels[i]==1f && predicts[i][0]<=0) { - error++; - } - } - - return error/labels.length; - } - } - - public static void main(String[] args) throws XGBoostError { - //load train mat (svmlight format) - DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); - //load valid mat (svmlight format) - DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test"); - - //set params - //set params - Params param = new Params() { - { - put("eta", 1.0); - put("max_depth", 2); - put("silent", 1); - } - }; - - //set round - int round = 2; - - //specify watchList - List> watchs = new ArrayList<>(); - watchs.add(new AbstractMap.SimpleEntry<>("train", trainMat)); - watchs.add(new AbstractMap.SimpleEntry<>("test", testMat)); - - //user define obj and eval - IObjective obj = new LogRegObj(); - IEvaluation eval = new EvalError(); - - //train a booster - System.out.println("begin to train the booster model"); - Booster booster = Trainer.train(param, trainMat, round, watchs, obj, eval); - } -} diff --git a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/ExternalMemory.java b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/ExternalMemory.java deleted file mode 100644 index b0a9d27dc..000000000 --- a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/ExternalMemory.java +++ /dev/null @@ -1,65 +0,0 @@ -/* - Copyright (c) 2014 by Contributors - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in 
compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - */ -package org.dmlc.xgboost4j.demo; - -import java.util.AbstractMap; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import org.dmlc.xgboost4j.Booster; -import org.dmlc.xgboost4j.DMatrix; -import org.dmlc.xgboost4j.demo.util.Params; -import org.dmlc.xgboost4j.util.Trainer; -import org.dmlc.xgboost4j.util.XGBoostError; - -/** - * simple example for using external memory version - * @author hzx - */ -public class ExternalMemory { - public static void main(String[] args) throws XGBoostError { - //this is the only difference, add a # followed by a cache prefix name - //several cache file with the prefix will be generated - //currently only support convert from libsvm file - DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train#dtrain.cache"); - DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test#dtest.cache"); - - //specify parameters - Params param = new Params() { - { - put("eta", 1.0); - put("max_depth", 2); - put("silent", 1); - put("objective", "binary:logistic"); - } - }; - - //performance notice: set nthread to be the number of your real cpu - //some cpu offer two threads per core, for example, a 4 core cpu with 8 threads, in such case set nthread=4 - //param.put("nthread", num_real_cpu); - - //specify watchList - List> watchs = new ArrayList<>(); - watchs.add(new AbstractMap.SimpleEntry<>("train", trainMat)); - watchs.add(new AbstractMap.SimpleEntry<>("test", testMat)); - - //set round - int round = 2; - - //train a boost model - Booster booster = Trainer.train(param, trainMat, round, watchs, null, null); - } -} diff --git a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/GeneralizedLinearModel.java b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/GeneralizedLinearModel.java deleted file mode 100644 index 7d3d717bd..000000000 --- a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/GeneralizedLinearModel.java +++ /dev/null @@ -1,74 +0,0 @@ -/* - Copyright (c) 2014 by Contributors - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
- */ -package org.dmlc.xgboost4j.demo; - -import java.util.AbstractMap; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import org.dmlc.xgboost4j.Booster; -import org.dmlc.xgboost4j.DMatrix; -import org.dmlc.xgboost4j.demo.util.CustomEval; -import org.dmlc.xgboost4j.demo.util.Params; -import org.dmlc.xgboost4j.util.Trainer; -import org.dmlc.xgboost4j.util.XGBoostError; - -/** - * this is an example of fit generalized linear model in xgboost - * basically, we are using linear model, instead of tree for our boosters - * @author hzx - */ -public class GeneralizedLinearModel { - public static void main(String[] args) throws XGBoostError { - // load file from text file, also binary buffer generated by xgboost4j - DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); - DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test"); - - //specify parameters - //change booster to gblinear, so that we are fitting a linear model - // alpha is the L1 regularizer - //lambda is the L2 regularizer - //you can also set lambda_bias which is L2 regularizer on the bias term - Params param = new Params() { - { - put("alpha", 0.0001); - put("silent", 1); - put("objective", "binary:logistic"); - put("booster", "gblinear"); - } - }; - //normally, you do not need to set eta (step_size) - //XGBoost uses a parallel coordinate descent algorithm (shotgun), - //there could be affection on convergence with parallelization on certain cases - //setting eta to be smaller value, e.g 0.5 can make the optimization more stable - //param.put("eta", "0.5"); - - - //specify watchList - List> watchs = new ArrayList<>(); - watchs.add(new AbstractMap.SimpleEntry<>("train", trainMat)); - watchs.add(new AbstractMap.SimpleEntry<>("test", testMat)); - - //train a booster - int round = 4; - Booster booster = Trainer.train(param, trainMat, round, watchs, null, null); - - float[][] predicts = booster.predict(testMat); - - CustomEval eval = new CustomEval(); - System.out.println("error=" + eval.eval(predicts, testMat)); - } -} diff --git a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/PredictFirstNtree.java b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/PredictFirstNtree.java deleted file mode 100644 index 2bbd1fd6c..000000000 --- a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/PredictFirstNtree.java +++ /dev/null @@ -1,69 +0,0 @@ -/* - Copyright (c) 2014 by Contributors - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
- */ -package org.dmlc.xgboost4j.demo; - -import java.util.AbstractMap; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import org.dmlc.xgboost4j.Booster; -import org.dmlc.xgboost4j.DMatrix; -import org.dmlc.xgboost4j.util.Trainer; - -import org.dmlc.xgboost4j.demo.util.CustomEval; -import org.dmlc.xgboost4j.demo.util.Params; -import org.dmlc.xgboost4j.util.XGBoostError; - -/** - * predict first ntree - * @author hzx - */ -public class PredictFirstNtree { - public static void main(String[] args) throws XGBoostError { - // load file from text file, also binary buffer generated by xgboost4j - DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); - DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test"); - - //specify parameters - Params param = new Params() { - { - put("eta", 1.0); - put("max_depth", 2); - put("silent", 1); - put("objective", "binary:logistic"); - } - }; - - //specify watchList - List> watchs = new ArrayList<>(); - watchs.add(new AbstractMap.SimpleEntry<>("train", trainMat)); - watchs.add(new AbstractMap.SimpleEntry<>("test", testMat)); - - //train a booster - int round = 3; - Booster booster = Trainer.train(param, trainMat, round, watchs, null, null); - - //predict use 1 tree - float[][] predicts1 = booster.predict(testMat, false, 1); - //by default all trees are used to do predict - float[][] predicts2 = booster.predict(testMat); - - //use a simple evaluation class to check error result - CustomEval eval = new CustomEval(); - System.out.println("error of predicts1: " + eval.eval(predicts1, testMat)); - System.out.println("error of predicts2: " + eval.eval(predicts2, testMat)); - } -} diff --git a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/PredictLeafIndices.java b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/PredictLeafIndices.java deleted file mode 100644 index ede103aeb..000000000 --- a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/PredictLeafIndices.java +++ /dev/null @@ -1,70 +0,0 @@ -/* - Copyright (c) 2014 by Contributors - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
- */ -package org.dmlc.xgboost4j.demo; - -import java.util.AbstractMap; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import java.util.Map; -import org.dmlc.xgboost4j.Booster; -import org.dmlc.xgboost4j.DMatrix; -import org.dmlc.xgboost4j.util.Trainer; -import org.dmlc.xgboost4j.demo.util.Params; -import org.dmlc.xgboost4j.util.XGBoostError; - -/** - * predict leaf indices - * @author hzx - */ -public class PredictLeafIndices { - public static void main(String[] args) throws XGBoostError { - // load file from text file, also binary buffer generated by xgboost4j - DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); - DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test"); - - //specify parameters - Params param = new Params() { - { - put("eta", 1.0); - put("max_depth", 2); - put("silent", 1); - put("objective", "binary:logistic"); - } - }; - - //specify watchList - List> watchs = new ArrayList<>(); - watchs.add(new AbstractMap.SimpleEntry<>("train", trainMat)); - watchs.add(new AbstractMap.SimpleEntry<>("test", testMat)); - - //train a booster - int round = 3; - Booster booster = Trainer.train(param, trainMat, round, watchs, null, null); - - //predict using first 2 tree - float[][] leafindex = booster.predict(testMat, 2, true); - for(float[] leafs : leafindex) { - System.out.println(Arrays.toString(leafs)); - } - - //predict all trees - leafindex = booster.predict(testMat, 0, true); - for(float[] leafs : leafindex) { - System.out.println(Arrays.toString(leafs)); - } - } -} diff --git a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/util/CustomEval.java b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/util/CustomEval.java deleted file mode 100644 index 5f25278d5..000000000 --- a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/util/CustomEval.java +++ /dev/null @@ -1,60 +0,0 @@ -/* - Copyright (c) 2014 by Contributors - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
- */ -package org.dmlc.xgboost4j.demo.util; - -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; -import org.dmlc.xgboost4j.DMatrix; -import org.dmlc.xgboost4j.IEvaluation; -import org.dmlc.xgboost4j.util.XGBoostError; - -/** - * a util evaluation class for examples - * @author hzx - */ -public class CustomEval implements IEvaluation { - private static final Log logger = LogFactory.getLog(CustomEval.class); - - String evalMetric = "custom_error"; - - @Override - public String getMetric() { - return evalMetric; - } - - @Override - public float eval(float[][] predicts, DMatrix dmat) { - float error = 0f; - float[] labels; - try { - labels = dmat.getLabel(); - } catch (XGBoostError ex) { - logger.error(ex); - return -1f; - } - int nrow = predicts.length; - for(int i=0; i0.5) { - error++; - } - else if(labels[i]==1f && predicts[i][0]<=0.5) { - error++; - } - } - - return error/labels.length; - } -} diff --git a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/util/DataLoader.java b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/util/DataLoader.java deleted file mode 100644 index 9bad8b372..000000000 --- a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/util/DataLoader.java +++ /dev/null @@ -1,127 +0,0 @@ -/* - Copyright (c) 2014 by Contributors - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
- */ -package org.dmlc.xgboost4j.demo.util; - -import java.io.BufferedReader; -import java.io.File; -import java.io.FileInputStream; -import java.io.FileNotFoundException; -import java.io.IOException; -import java.io.InputStreamReader; -import java.io.UnsupportedEncodingException; -import java.util.ArrayList; -import java.util.List; -import org.apache.commons.lang3.ArrayUtils; - -/** - * util class for loading data - * @author hzx - */ -public class DataLoader { - public static class DenseData { - public float[] labels; - public float[] data; - public int nrow; - public int ncol; - } - - public static class CSRSparseData { - public float[] labels; - public float[] data; - public long[] rowHeaders; - public int[] colIndex; - } - - public static DenseData loadCSVFile(String filePath) throws FileNotFoundException, UnsupportedEncodingException, IOException { - DenseData denseData = new DenseData(); - - File f = new File(filePath); - FileInputStream in = new FileInputStream(f); - BufferedReader reader = new BufferedReader(new InputStreamReader(in, "UTF-8")); - - denseData.nrow = 0; - denseData.ncol = -1; - String line; - List tlabels = new ArrayList<>(); - List tdata = new ArrayList<>(); - - while((line=reader.readLine()) != null) { - String[] items = line.trim().split(","); - if(items.length==0) { - continue; - } - denseData.nrow++; - if(denseData.ncol == -1) { - denseData.ncol = items.length - 1; - } - - tlabels.add(Float.valueOf(items[items.length-1])); - for(int i=0; i tlabels = new ArrayList<>(); - List tdata = new ArrayList<>(); - List theaders = new ArrayList<>(); - List tindex = new ArrayList<>(); - - File f = new File(filePath); - FileInputStream in = new FileInputStream(f); - BufferedReader reader = new BufferedReader(new InputStreamReader(in, "UTF-8")); - - String line; - long rowheader = 0; - theaders.add(rowheader); - while((line=reader.readLine()) != null) { - String[] items = line.trim().split(" "); - if(items.length==0) { - continue; - } - - rowheader += items.length - 1; - theaders.add(rowheader); - tlabels.add(Float.valueOf(items[0])); - - for(int i=1; i>{ - List> params = new ArrayList<>(); - - /** - * put param key-value pair - * @param key - * @param value - */ - public void put(String key, Object value) { - params.add(new AbstractMap.SimpleEntry<>(key, value)); - } - - @Override - public String toString(){ - String paramsInfo = ""; - for(Entry param : params) { - paramsInfo += param.getKey() + ":" + param.getValue() + "\n"; - } - return paramsInfo; - } - - @Override - public Iterator> iterator() { - return params.iterator(); - } -} diff --git a/java/xgboost4j/pom.xml b/java/xgboost4j/pom.xml deleted file mode 100644 index 5e312bf4f..000000000 --- a/java/xgboost4j/pom.xml +++ /dev/null @@ -1,35 +0,0 @@ - - - 4.0.0 - org.dmlc - xgboost4j - 1.1 - jar - - UTF-8 - 1.7 - 1.7 - - - - - org.apache.maven.plugins - maven-javadoc-plugin - 2.10.3 - - - - - - junit - junit - 4.11 - test - - - commons-logging - commons-logging - 1.2 - - - \ No newline at end of file diff --git a/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/Booster.java b/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/Booster.java deleted file mode 100644 index 64c89ae06..000000000 --- a/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/Booster.java +++ /dev/null @@ -1,484 +0,0 @@ -/* - Copyright (c) 2014 by Contributors - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. 
- You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - */ -package org.dmlc.xgboost4j; - -import java.io.BufferedWriter; -import java.io.File; -import java.io.FileNotFoundException; -import java.io.FileOutputStream; -import java.io.IOException; -import java.io.OutputStreamWriter; -import java.io.UnsupportedEncodingException; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Map.Entry; -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; - -import org.dmlc.xgboost4j.util.Initializer; -import org.dmlc.xgboost4j.util.ErrorHandle; -import org.dmlc.xgboost4j.util.XGBoostError; -import org.dmlc.xgboost4j.wrapper.XgboostJNI; - - -/** - * Booster for xgboost, similar to the python wrapper xgboost.py - * but custom obj function and eval function not supported at present. - * @author hzx - */ -public final class Booster { - private static final Log logger = LogFactory.getLog(Booster.class); - - long handle = 0; - - //load native library - static { - try { - Initializer.InitXgboost(); - } catch (IOException ex) { - logger.error("load native library failed."); - logger.error(ex); - } - } - - /** - * init Booster from dMatrixs - * @param params parameters - * @param dMatrixs DMatrix array - * @throws org.dmlc.xgboost4j.util.XGBoostError native error - */ - public Booster(Iterable> params, DMatrix[] dMatrixs) throws XGBoostError { - init(dMatrixs); - setParam("seed","0"); - setParams(params); - } - - - - /** - * load model from modelPath - * @param params parameters - * @param modelPath booster modelPath (model generated by booster.saveModel) - * @throws org.dmlc.xgboost4j.util.XGBoostError native error - */ - public Booster(Iterable> params, String modelPath) throws XGBoostError { - init(null); - if(modelPath == null) { - throw new NullPointerException("modelPath : null"); - } - loadModel(modelPath); - setParam("seed","0"); - setParams(params); - } - - - - - private void init(DMatrix[] dMatrixs) throws XGBoostError { - long[] handles = null; - if(dMatrixs != null) { - handles = dMatrixs2handles(dMatrixs); - } - long[] out = new long[1]; - ErrorHandle.checkCall(XgboostJNI.XGBoosterCreate(handles, out)); - - handle = out[0]; - } - - /** - * set parameter - * @param key param name - * @param value param value - * @throws org.dmlc.xgboost4j.util.XGBoostError native error - */ - public final void setParam(String key, String value) throws XGBoostError { - ErrorHandle.checkCall(XgboostJNI.XGBoosterSetParam(handle, key, value)); - } - - /** - * set parameters - * @param params parameters key-value map - * @throws org.dmlc.xgboost4j.util.XGBoostError native error - */ - public void setParams(Iterable> params) throws XGBoostError { - if(params!=null) { - for(Map.Entry entry : params) { - setParam(entry.getKey(), entry.getValue().toString()); - } - } - } - - - /** - * Update (one iteration) - * @param dtrain training data - * @param iter current iteration number - * @throws org.dmlc.xgboost4j.util.XGBoostError native error - */ - public void update(DMatrix dtrain, int iter) throws XGBoostError { - ErrorHandle.checkCall(XgboostJNI.XGBoosterUpdateOneIter(handle, iter, dtrain.getHandle())); - } 
- - /** - * update with customize obj func - * @param dtrain training data - * @param iter current iteration number - * @param obj customized objective class - * @throws org.dmlc.xgboost4j.util.XGBoostError native error - */ - public void update(DMatrix dtrain, int iter, IObjective obj) throws XGBoostError { - float[][] predicts = predict(dtrain, true); - List gradients = obj.getGradient(predicts, dtrain); - boost(dtrain, gradients.get(0), gradients.get(1)); - } - - /** - * update with give grad and hess - * @param dtrain training data - * @param grad first order of gradient - * @param hess seconde order of gradient - * @throws org.dmlc.xgboost4j.util.XGBoostError native error - */ - public void boost(DMatrix dtrain, float[] grad, float[] hess) throws XGBoostError { - if(grad.length != hess.length) { - throw new AssertionError(String.format("grad/hess length mismatch %s / %s", grad.length, hess.length)); - } - ErrorHandle.checkCall(XgboostJNI.XGBoosterBoostOneIter(handle, dtrain.getHandle(), grad, hess)); - } - - /** - * evaluate with given dmatrixs. - * @param evalMatrixs dmatrixs for evaluation - * @param evalNames name for eval dmatrixs, used for check results - * @param iter current eval iteration - * @return eval information - * @throws org.dmlc.xgboost4j.util.XGBoostError native error - */ - public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, int iter) throws XGBoostError { - long[] handles = dMatrixs2handles(evalMatrixs); - String[] evalInfo = new String[1]; - ErrorHandle.checkCall(XgboostJNI.XGBoosterEvalOneIter(handle, iter, handles, evalNames, evalInfo)); - return evalInfo[0]; - } - - /** - * evaluate with given customized Evaluation class - * @param evalMatrixs evaluation matrix - * @param evalNames evaluation names - * @param iter number of interations - * @param eval custom evaluator - * @return eval information - * @throws org.dmlc.xgboost4j.util.XGBoostError native error - */ - public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, int iter, IEvaluation eval) throws XGBoostError { - String evalInfo = ""; - for(int i=0; i getFeatureScore() throws XGBoostError { - String[] modelInfos = getDumpInfo(false); - Map featureScore = new HashMap<>(); - for(String tree : modelInfos) { - for(String node : tree.split("\n")) { - String[] array = node.split("\\["); - if(array.length == 1) { - continue; - } - String fid = array[1].split("\\]")[0]; - fid = fid.split("<")[0]; - if(featureScore.containsKey(fid)) { - featureScore.put(fid, 1 + featureScore.get(fid)); - } - else { - featureScore.put(fid, 1); - } - } - } - return featureScore; - } - - - /** - * get importance of each feature - * @param featureMap file to save dumped model info - * @return featureMap key: feature index, value: feature importance score - * @throws org.dmlc.xgboost4j.util.XGBoostError native error - */ - public Map getFeatureScore(String featureMap) throws XGBoostError { - String[] modelInfos = getDumpInfo(featureMap, false); - Map featureScore = new HashMap<>(); - for(String tree : modelInfos) { - for(String node : tree.split("\n")) { - String[] array = node.split("\\["); - if(array.length == 1) { - continue; - } - String fid = array[1].split("\\]")[0]; - fid = fid.split("<")[0]; - if(featureScore.containsKey(fid)) { - featureScore.put(fid, 1 + featureScore.get(fid)); - } - else { - featureScore.put(fid, 1); - } - } - } - return featureScore; - } - - /** - * transfer DMatrix array to handle array (used for native functions) - * @param dmatrixs - * @return handle array for input dmatrixs - */ - 
private static long[] dMatrixs2handles(DMatrix[] dmatrixs) { - long[] handles = new long[dmatrixs.length]; - for(int i=0; i> params) throws XGBoostError { - dmats = new DMatrix[] {dtrain, dtest}; - booster = new Booster(params, dmats); - names = new String[] {"train", "test"}; - this.dtrain = dtrain; - this.dtest = dtest; - } - - /** - * update one iteration - * @param iter iteration num - * @throws org.dmlc.xgboost4j.util.XGBoostError native error - */ - public void update(int iter) throws XGBoostError { - booster.update(dtrain, iter); - } - - /** - * update one iteration - * @param iter iteration num - * @param obj customized objective - * @throws org.dmlc.xgboost4j.util.XGBoostError native error - */ - public void update(int iter, IObjective obj) throws XGBoostError { - booster.update(dtrain, iter, obj); - } - - /** - * evaluation - * @param iter iteration num - * @return evaluation - * @throws org.dmlc.xgboost4j.util.XGBoostError native error - */ - public String eval(int iter) throws XGBoostError { - return booster.evalSet(dmats, names, iter); - } - - /** - * evaluation - * @param iter iteration num - * @param eval customized eval - * @return evaluation - * @throws org.dmlc.xgboost4j.util.XGBoostError native error - */ - public String eval(int iter, IEvaluation eval) throws XGBoostError { - return booster.evalSet(dmats, names, iter, eval); - } -} diff --git a/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/util/ErrorHandle.java b/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/util/ErrorHandle.java deleted file mode 100644 index aad9f6174..000000000 --- a/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/util/ErrorHandle.java +++ /dev/null @@ -1,49 +0,0 @@ -/* - Copyright (c) 2014 by Contributors - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - */ -package org.dmlc.xgboost4j.util; - -import java.io.IOException; -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; -import org.dmlc.xgboost4j.wrapper.XgboostJNI; - -/** - * Error handle for Xgboost. - */ -public class ErrorHandle { - private static final Log logger = LogFactory.getLog(ErrorHandle.class); - - //load native library - static { - try { - Initializer.InitXgboost(); - } catch (IOException ex) { - logger.error("load native library failed."); - logger.error(ex); - } - } - - /** - * Check the return value of C API. 
- * @param ret return valud of xgboostJNI C API call - * @throws org.dmlc.xgboost4j.util.XGBoostError native error - */ - public static void checkCall(int ret) throws XGBoostError { - if(ret != 0) { - throw new XGBoostError(XgboostJNI.XGBGetLastError()); - } - } -} diff --git a/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/util/Initializer.java b/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/util/Initializer.java deleted file mode 100644 index 5dbbe4b28..000000000 --- a/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/util/Initializer.java +++ /dev/null @@ -1,92 +0,0 @@ -/* - Copyright (c) 2014 by Contributors - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - */ -package org.dmlc.xgboost4j.util; - -import java.io.IOException; -import java.lang.reflect.Field; - -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; - -/** - * class to load native library - * @author hzx - */ -public class Initializer { - private static final Log logger = LogFactory.getLog(Initializer.class); - - static boolean initialized = false; - public static final String nativePath = "./lib"; - public static final String nativeResourcePath = "/lib/"; - public static final String[] libNames = new String[] {"xgboost4j"}; - - public static synchronized void InitXgboost() throws IOException { - if(initialized == false) { - for(String libName: libNames) { - smartLoad(libName); - } - initialized = true; - } - } - - /** - * load native library, this method will first try to load library from java.library.path, then try to load library in jar package. 
- * @param libName library path - * @throws IOException exception - */ - private static void smartLoad(String libName) throws IOException { - addNativeDir(nativePath); - try { - System.loadLibrary(libName); - } - catch (UnsatisfiedLinkError e) { - try { - NativeUtils.loadLibraryFromJar(nativeResourcePath + System.mapLibraryName(libName)); - } - catch (IOException e1) { - throw e1; - } - } - } - - /** - * Add libPath to java.library.path, then native library in libPath would be load properly - * @param libPath library path - * @throws IOException exception - */ - public static void addNativeDir(String libPath) throws IOException { - try { - Field field = ClassLoader.class.getDeclaredField("usr_paths"); - field.setAccessible(true); - String[] paths = (String[]) field.get(null); - for (String path : paths) { - if (libPath.equals(path)) { - return; - } - } - String[] tmp = new String[paths.length+1]; - System.arraycopy(paths,0,tmp,0,paths.length); - tmp[paths.length] = libPath; - field.set(null, tmp); - } catch (IllegalAccessException e) { - logger.error(e.getMessage()); - throw new IOException("Failed to get permissions to set library path"); - } catch (NoSuchFieldException e) { - logger.error(e.getMessage()); - throw new IOException("Failed to get field handle to set library path"); - } - } -} diff --git a/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/util/NativeUtils.java b/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/util/NativeUtils.java deleted file mode 100644 index 77e299fa2..000000000 --- a/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/util/NativeUtils.java +++ /dev/null @@ -1,113 +0,0 @@ -/* - Copyright (c) 2014 by Contributors - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - */ -package org.dmlc.xgboost4j.util; - -import java.io.File; -import java.io.FileNotFoundException; -import java.io.FileOutputStream; -import java.io.IOException; -import java.io.InputStream; -import java.io.OutputStream; - - -/** - * Simple library class for working with JNI (Java Native Interface) - *

- * See - * http://adamheinrich.com/2012/how-to-load-native-jni-library-from-jar - *

- * Author Adam Heirnich <adam@adamh.cz>, http://www.adamh.cz - */ -public class NativeUtils { - - /** - * Private constructor - this class will never be instanced - */ - private NativeUtils() { - } - - /** - * Loads library from current JAR archive - *

- * The file from JAR is copied into system temporary directory and then loaded. - * The temporary file is deleted after exiting. - * Method uses String as filename because the pathname is "abstract", not system-dependent. - *

- * The restrictions of {@link File#createTempFile(java.lang.String, java.lang.String)} apply to {@code path}. - * - * @param path The filename inside JAR as absolute path (beginning with '/'), e.g. /package/File.ext - * @throws IOException If temporary file creation or read/write operation fails - * @throws IllegalArgumentException If source file (param path) does not exist - * @throws IllegalArgumentException If the path is not absolute or if the filename is shorter than three characters - */ - public static void loadLibraryFromJar(String path) throws IOException { - - if (!path.startsWith("/")) { - throw new IllegalArgumentException("The path has to be absolute (start with '/')."); - } - - // Obtain filename from path - String[] parts = path.split("/"); - String filename = (parts.length > 1) ? parts[parts.length - 1] : null; - - // Split filename to prexif and suffix (extension) - String prefix = ""; - String suffix = null; - if (filename != null) { - parts = filename.split("\\.", 2); - prefix = parts[0]; - suffix = (parts.length > 1) ? "."+parts[parts.length - 1] : null; // Thanks, davs! :-) - } - - // Check if the filename is okay - if (filename == null || prefix.length() < 3) { - throw new IllegalArgumentException("The filename has to be at least 3 characters long."); - } - - // Prepare temporary file - File temp = File.createTempFile(prefix, suffix); - temp.deleteOnExit(); - - if (!temp.exists()) { - throw new FileNotFoundException("File " + temp.getAbsolutePath() + " does not exist."); - } - - // Prepare buffer for data copying - byte[] buffer = new byte[1024]; - int readBytes; - - // Open and check input stream - InputStream is = NativeUtils.class.getResourceAsStream(path); - if (is == null) { - throw new FileNotFoundException("File " + path + " was not found inside JAR."); - } - - // Open output stream and copy data between source file in JAR and the temporary file - OutputStream os = new FileOutputStream(temp); - try { - while ((readBytes = is.read(buffer)) != -1) { - os.write(buffer, 0, readBytes); - } - } finally { - // If read/write fails, close streams safely before throwing an exception - os.close(); - is.close(); - } - - // Finally, load the library - System.load(temp.getAbsolutePath()); - } -} diff --git a/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/util/Trainer.java b/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/util/Trainer.java deleted file mode 100644 index 994a8b4ac..000000000 --- a/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/util/Trainer.java +++ /dev/null @@ -1,238 +0,0 @@ -/* - Copyright (c) 2014 by Contributors - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
- */ -package org.dmlc.xgboost4j.util; - -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Map.Entry; -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; -import org.dmlc.xgboost4j.IEvaluation; -import org.dmlc.xgboost4j.Booster; -import org.dmlc.xgboost4j.DMatrix; -import org.dmlc.xgboost4j.IObjective; - - -/** - * trainer for xgboost - * @author hzx - */ -public class Trainer { - private static final Log logger = LogFactory.getLog(Trainer.class); - - /** - * Train a booster with given parameters. - * @param params Booster params. - * @param dtrain Data to be trained. - * @param round Number of boosting iterations. - * @param watchs a group of items to be evaluated during training, this allows user to watch performance on the validation set. - * @param obj customized objective (set to null if not used) - * @param eval customized evaluation (set to null if not used) - * @return trained booster - * @throws org.dmlc.xgboost4j.util.XGBoostError native error - */ - public static Booster train(Iterable> params, DMatrix dtrain, int round, - Iterable> watchs, IObjective obj, IEvaluation eval) throws XGBoostError { - - //collect eval matrixs - String[] evalNames; - DMatrix[] evalMats; - List names = new ArrayList<>(); - List mats = new ArrayList<>(); - - for(Entry evalEntry : watchs) { - names.add(evalEntry.getKey()); - mats.add(evalEntry.getValue()); - } - - evalNames = names.toArray(new String[names.size()]); - evalMats = mats.toArray(new DMatrix[mats.size()]); - - //collect all data matrixs - DMatrix[] allMats; - if(evalMats!=null && evalMats.length>0) { - allMats = new DMatrix[evalMats.length+1]; - allMats[0] = dtrain; - System.arraycopy(evalMats, 0, allMats, 1, evalMats.length); - } - else { - allMats = new DMatrix[1]; - allMats[0] = dtrain; - } - - //initialize booster - Booster booster = new Booster(params, allMats); - - //begin to train - for(int iter=0; iter0) { - String evalInfo; - if(eval != null) { - evalInfo = booster.evalSet(evalMats, evalNames, iter, eval); - } - else { - evalInfo = booster.evalSet(evalMats, evalNames, iter); - } - logger.info(evalInfo); - } - } - return booster; - } - - /** - * Cross-validation with given paramaters. - * @param params Booster params. - * @param data Data to be trained. - * @param round Number of boosting iterations. - * @param nfold Number of folds in CV. - * @param metrics Evaluation metrics to be watched in CV. 
- * @param obj customized objective (set to null if not used) - * @param eval customized evaluation (set to null if not used) - * @return evaluation history - * @throws org.dmlc.xgboost4j.util.XGBoostError native error - */ - public static String[] crossValiation(Iterable> params, DMatrix data, int round, int nfold, String[] metrics, IObjective obj, IEvaluation eval) throws XGBoostError { - CVPack[] cvPacks = makeNFold(data, nfold, params, metrics); - String[] evalHist = new String[round]; - String[] results = new String[cvPacks.length]; - for(int i=0; i> params, String[] evalMetrics) throws XGBoostError { - List samples = genRandPermutationNums(0, (int) data.rowNum()); - int step = samples.size()/nfold; - int[] testSlice = new int[step]; - int[] trainSlice = new int[samples.size()-step]; - int testid, trainid; - CVPack[] cvPacks = new CVPack[nfold]; - for(int i=0; i(i*step) && j<(i*step+step) && testid genRandPermutationNums(int start, int end) { - List samples = new ArrayList<>(); - for(int i=start; i > cvMap = new HashMap<>(); - String aggResult = results[0].split("\t")[0]; - for(String result : results) { - String[] items = result.split("\t"); - for(int i=1; i()); - } - cvMap.get(key).add(value); - } - } - - for(String key : cvMap.keySet()) { - float value = 0f; - for(Float tvalue : cvMap.get(key)) { - value += tvalue; - } - value /= cvMap.get(key).size(); - aggResult += String.format("\tcv-%s:%f", key, value); - } - - return aggResult; - } -} diff --git a/java/xgboost4j/src/test/java/org/dmlc/xgboost4j/BoosterTest.java b/java/xgboost4j/src/test/java/org/dmlc/xgboost4j/BoosterTest.java deleted file mode 100644 index 20c64b316..000000000 --- a/java/xgboost4j/src/test/java/org/dmlc/xgboost4j/BoosterTest.java +++ /dev/null @@ -1,142 +0,0 @@ -/* - Copyright (c) 2014 by Contributors - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
- */ -package org.dmlc.xgboost4j; - -import java.util.AbstractMap; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Map.Entry; -import junit.framework.TestCase; -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; -import org.dmlc.xgboost4j.util.Trainer; -import org.dmlc.xgboost4j.util.XGBoostError; -import org.junit.Test; - -/** - * test cases for Booster - * @author hzx - */ -public class BoosterTest { - public static class EvalError implements IEvaluation { - private static final Log logger = LogFactory.getLog(EvalError.class); - - String evalMetric = "custom_error"; - - public EvalError() { - } - - @Override - public String getMetric() { - return evalMetric; - } - - @Override - public float eval(float[][] predicts, DMatrix dmat) { - float error = 0f; - float[] labels; - try { - labels = dmat.getLabel(); - } catch (XGBoostError ex) { - logger.error(ex); - return -1f; - } - int nrow = predicts.length; - for(int i=0; i0) { - error++; - } - else if(labels[i]==1f && predicts[i][0]<=0) { - error++; - } - } - - return error/labels.length; - } - } - - @Test - public void testBoosterBasic() throws XGBoostError { - DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); - DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test"); - - //set params - Map paramMap = new HashMap() { - { - put("eta", 1.0); - put("max_depth", 2); - put("silent", 1); - put("objective", "binary:logistic"); - } - }; - Iterable> param = paramMap.entrySet(); - - //set watchList - List> watchs = new ArrayList<>(); - watchs.add(new AbstractMap.SimpleEntry<>("train", trainMat)); - watchs.add(new AbstractMap.SimpleEntry<>("test", testMat)); - - //set round - int round = 2; - - //train a boost model - Booster booster = Trainer.train(param, trainMat, round, watchs, null, null); - - //predict raw output - float[][] predicts = booster.predict(testMat, true); - - //eval - IEvaluation eval = new EvalError(); - //error must be less than 0.1 - TestCase.assertTrue(eval.eval(predicts, testMat)<0.1f); - - //test dump model - - } - - /** - * test cross valiation - * @throws XGBoostError - */ - @Test - public void testCV() throws XGBoostError { - //load train mat - DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); - - //set params - Map param= new HashMap() { - { - put("eta", 1.0); - put("max_depth", 3); - put("silent", 1); - put("nthread", 6); - put("objective", "binary:logistic"); - put("gamma", 1.0); - put("eval_metric", "error"); - } - }; - - //do 5-fold cross validation - int round = 2; - int nfold = 5; - //set additional eval_metrics - String[] metrics = null; - - String[] evalHist = Trainer.crossValiation(param.entrySet(), trainMat, round, nfold, metrics, null, null); - } -} diff --git a/java/xgboost4j/src/test/java/org/dmlc/xgboost4j/DMatrixTest.java b/java/xgboost4j/src/test/java/org/dmlc/xgboost4j/DMatrixTest.java deleted file mode 100644 index 343dd3ed9..000000000 --- a/java/xgboost4j/src/test/java/org/dmlc/xgboost4j/DMatrixTest.java +++ /dev/null @@ -1,102 +0,0 @@ -/* - Copyright (c) 2014 by Contributors - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. 
- You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - */ -package org.dmlc.xgboost4j; - -import java.util.Arrays; -import java.util.Random; -import junit.framework.TestCase; -import org.dmlc.xgboost4j.util.XGBoostError; -import org.junit.Test; - -/** - * test cases for DMatrix - * @author hzx - */ -public class DMatrixTest { - - @Test - public void testCreateFromFile() throws XGBoostError { - //create DMatrix from file - DMatrix dmat = new DMatrix("../../demo/data/agaricus.txt.test"); - //get label - float[] labels = dmat.getLabel(); - //check length - TestCase.assertTrue(dmat.rowNum()==labels.length); - //set weights - float[] weights = Arrays.copyOf(labels, labels.length); - dmat.setWeight(weights); - float[] dweights = dmat.getWeight(); - TestCase.assertTrue(Arrays.equals(weights, dweights)); - } - - @Test - public void testCreateFromCSR() throws XGBoostError { - //create Matrix from csr format sparse Matrix and labels - /** - * sparse matrix - * 1 0 2 3 0 - * 4 0 2 3 5 - * 3 1 2 5 0 - */ - float[] data = new float[] {1, 2, 3, 4, 2, 3, 5, 3, 1, 2, 5}; - int[] colIndex = new int[] {0, 2, 3, 0, 2, 3, 4, 0, 1, 2, 3}; - long[] rowHeaders = new long[] {0, 3, 7, 11}; - DMatrix dmat1 = new DMatrix(rowHeaders, colIndex, data, DMatrix.SparseType.CSR); - //check row num - System.out.println(dmat1.rowNum()); - TestCase.assertTrue(dmat1.rowNum()==3); - //test set label - float[] label1 = new float[] {1, 0, 1}; - dmat1.setLabel(label1); - float[] label2 = dmat1.getLabel(); - TestCase.assertTrue(Arrays.equals(label1, label2)); - } - - @Test - public void testCreateFromDenseMatrix() throws XGBoostError { - //create DMatrix from 10*5 dense matrix - int nrow = 10; - int ncol = 5; - float[] data0 = new float[nrow*ncol]; - //put random nums - Random random = new Random(); - for(int i=0; i -/* Header for class org_dmlc_xgboost4j_wrapper_XgboostJNI */ - -#ifndef _Included_org_dmlc_xgboost4j_wrapper_XgboostJNI -#define _Included_org_dmlc_xgboost4j_wrapper_XgboostJNI -#ifdef __cplusplus -extern "C" { -#endif -/* - * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI - * Method: XGBGetLastError - * Signature: ()Ljava/lang/String; - */ -JNIEXPORT jstring JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBGetLastError - (JNIEnv *, jclass); - -/* - * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI - * Method: XGDMatrixCreateFromFile - * Signature: (Ljava/lang/String;I[J)I - */ -JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromFile - (JNIEnv *, jclass, jstring, jint, jlongArray); - -/* - * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI - * Method: XGDMatrixCreateFromCSR - * Signature: ([J[I[F[J)I - */ -JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromCSR - (JNIEnv *, jclass, jlongArray, jintArray, jfloatArray, jlongArray); - -/* - * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI - * Method: XGDMatrixCreateFromCSC - * Signature: ([J[I[F[J)I - */ -JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromCSC - (JNIEnv *, jclass, jlongArray, jintArray, jfloatArray, jlongArray); - -/* - * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI - * Method: XGDMatrixCreateFromMat - * Signature: 
([FIIF[J)I - */ -JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromMat - (JNIEnv *, jclass, jfloatArray, jint, jint, jfloat, jlongArray); - -/* - * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI - * Method: XGDMatrixSliceDMatrix - * Signature: (J[I[J)I - */ -JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSliceDMatrix - (JNIEnv *, jclass, jlong, jintArray, jlongArray); - -/* - * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI - * Method: XGDMatrixFree - * Signature: (J)I - */ -JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixFree - (JNIEnv *, jclass, jlong); - -/* - * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI - * Method: XGDMatrixSaveBinary - * Signature: (JLjava/lang/String;I)I - */ -JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSaveBinary - (JNIEnv *, jclass, jlong, jstring, jint); - -/* - * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI - * Method: XGDMatrixSetFloatInfo - * Signature: (JLjava/lang/String;[F)I - */ -JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetFloatInfo - (JNIEnv *, jclass, jlong, jstring, jfloatArray); - -/* - * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI - * Method: XGDMatrixSetUIntInfo - * Signature: (JLjava/lang/String;[I)I - */ -JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetUIntInfo - (JNIEnv *, jclass, jlong, jstring, jintArray); - -/* - * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI - * Method: XGDMatrixSetGroup - * Signature: (J[I)I - */ -JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetGroup - (JNIEnv *, jclass, jlong, jintArray); - -/* - * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI - * Method: XGDMatrixGetFloatInfo - * Signature: (JLjava/lang/String;[[F)I - */ -JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixGetFloatInfo - (JNIEnv *, jclass, jlong, jstring, jobjectArray); - -/* - * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI - * Method: XGDMatrixGetUIntInfo - * Signature: (JLjava/lang/String;[[I)I - */ -JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixGetUIntInfo - (JNIEnv *, jclass, jlong, jstring, jobjectArray); - -/* - * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI - * Method: XGDMatrixNumRow - * Signature: (J[J)I - */ -JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixNumRow - (JNIEnv *, jclass, jlong, jlongArray); - -/* - * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI - * Method: XGBoosterCreate - * Signature: ([J[J)I - */ -JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterCreate - (JNIEnv *, jclass, jlongArray, jlongArray); - -/* - * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI - * Method: XGBoosterFree - * Signature: (J)I - */ -JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterFree - (JNIEnv *, jclass, jlong); - -/* - * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI - * Method: XGBoosterSetParam - * Signature: (JLjava/lang/String;Ljava/lang/String;)I - */ -JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterSetParam - (JNIEnv *, jclass, jlong, jstring, jstring); - -/* - * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI - * Method: XGBoosterUpdateOneIter - * Signature: (JIJ)I - */ -JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterUpdateOneIter - (JNIEnv *, jclass, jlong, jint, jlong); - -/* - * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI - * Method: XGBoosterBoostOneIter - * 
Signature: (JJ[F[F)I - */ -JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterBoostOneIter - (JNIEnv *, jclass, jlong, jlong, jfloatArray, jfloatArray); - -/* - * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI - * Method: XGBoosterEvalOneIter - * Signature: (JI[J[Ljava/lang/String;[Ljava/lang/String;)I - */ -JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterEvalOneIter - (JNIEnv *, jclass, jlong, jint, jlongArray, jobjectArray, jobjectArray); - -/* - * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI - * Method: XGBoosterPredict - * Signature: (JJIJ[[F)I - */ -JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterPredict - (JNIEnv *, jclass, jlong, jlong, jint, jint, jobjectArray); - -/* - * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI - * Method: XGBoosterLoadModel - * Signature: (JLjava/lang/String;)I - */ -JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterLoadModel - (JNIEnv *, jclass, jlong, jstring); - -/* - * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI - * Method: XGBoosterSaveModel - * Signature: (JLjava/lang/String;)I - */ -JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterSaveModel - (JNIEnv *, jclass, jlong, jstring); - -/* - * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI - * Method: XGBoosterLoadModelFromBuffer - * Signature: (JJJ)I - */ -JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterLoadModelFromBuffer - (JNIEnv *, jclass, jlong, jlong, jlong); - -/* - * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI - * Method: XGBoosterGetModelRaw - * Signature: (J[Ljava/lang/String;)I - */ -JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterGetModelRaw - (JNIEnv *, jclass, jlong, jobjectArray); - -/* - * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI - * Method: XGBoosterDumpModel - * Signature: (JLjava/lang/String;I[[Ljava/lang/String;)I - */ -JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterDumpModel - (JNIEnv *, jclass, jlong, jstring, jint, jobjectArray); - -#ifdef __cplusplus -} -#endif -#endif diff --git a/java/README.md b/jvm-packages/README.md similarity index 100% rename from java/README.md rename to jvm-packages/README.md diff --git a/jvm-packages/checkstyle-suppressions.xml b/jvm-packages/checkstyle-suppressions.xml new file mode 100644 index 000000000..21550e139 --- /dev/null +++ b/jvm-packages/checkstyle-suppressions.xml @@ -0,0 +1,33 @@ + + + + + + + + + diff --git a/jvm-packages/checkstyle.xml b/jvm-packages/checkstyle.xml new file mode 100644 index 000000000..9583ec282 --- /dev/null +++ b/jvm-packages/checkstyle.xml @@ -0,0 +1,169 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/java/create_wrap.bat b/jvm-packages/create_jni.bat similarity index 98% rename from java/create_wrap.bat rename to jvm-packages/create_jni.bat index ce4d99327..cbc0681c1 100644 --- a/java/create_wrap.bat +++ b/jvm-packages/create_jni.bat @@ -17,4 +17,4 @@ exit :end echo "source library not found, please build it first from ..\windows\xgboost.sln" pause - exit \ No newline at end of file + exit diff --git a/java/create_wrap.sh b/jvm-packages/create_jni.sh similarity index 81% rename from java/create_wrap.sh rename to jvm-packages/create_jni.sh index fb3b1f149..13e6a8556 100755 --- a/java/create_wrap.sh +++ b/jvm-packages/create_jni.sh 
@@ -16,8 +16,8 @@ if [ $(uname) == "Darwin" ]; then fi cd .. -make java no_omp=${dis_omp} -cd java +make jvm no_omp=${dis_omp} +cd jvm-packages echo "move native lib" libPath="xgboost4j/src/main/resources/lib" @@ -26,7 +26,7 @@ if [ ! -d "$libPath" ]; then fi rm -f xgboost4j/src/main/resources/lib/libxgboost4j.${dl} -mv libxgboost4j.so xgboost4j/src/main/resources/lib/libxgboost4j.${dl} +mv lib/libxgboost4j.so xgboost4j/src/main/resources/lib/libxgboost4j.${dl} popd > /dev/null echo "complete" diff --git a/java/doc/xgboost4j.md b/jvm-packages/doc/xgboost4j.md similarity index 100% rename from java/doc/xgboost4j.md rename to jvm-packages/doc/xgboost4j.md diff --git a/jvm-packages/pom.xml b/jvm-packages/pom.xml new file mode 100644 index 000000000..5ec221175 --- /dev/null +++ b/jvm-packages/pom.xml @@ -0,0 +1,117 @@ + + + 4.0.0 + + org.dmlc + xgboostjvm + 0.1 + pom + + UTF-8 + UTF-8 + 1.7 + 1.7 + 3.3.9 + 2.11.7 + 2.11 + + + xgboost4j + xgboost4j-demo + + + + + org.scalastyle + scalastyle-maven-plugin + 0.8.0 + + false + true + true + ${basedir}/src/main/scala + ${basedir}/src/test/scala + scalastyle-config.xml + UTF-8 + + + + checkstyle + validate + + check + + + + + + org.apache.maven.plugins + maven-checkstyle-plugin + 2.17 + + checkstyle.xml + true + + + + checkstyle + validate + + check + + + + + + net.alchim31.maven + scala-maven-plugin + 3.2.2 + + + compile + + compile + + compile + + + test-compile + + testCompile + + test-compile + + + process-resources + + compile + + + + + + org.apache.maven.plugins + maven-surefire-plugin + 2.19.1 + + -Djava.library.path=lib/ + + + + + + + commons-logging + commons-logging + 1.2 + + + org.scalatest + scalatest_${scala.binary.version} + 2.2.6 + test + + + diff --git a/jvm-packages/scalastyle-config.xml b/jvm-packages/scalastyle-config.xml new file mode 100644 index 000000000..27bb4fa8a --- /dev/null +++ b/jvm-packages/scalastyle-config.xml @@ -0,0 +1,291 @@ + + + + + Scalastyle standard configuration + + + + + + + + + + + + + + + + + + + + + + + + true + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + ARROW, EQUALS, ELSE, TRY, CATCH, FINALLY, LARROW, RARROW + + + + + + ARROW, EQUALS, COMMA, COLON, IF, ELSE, DO, WHILE, FOR, MATCH, TRY, CATCH, FINALLY, LARROW, RARROW + + + + + + + + + ^FunSuite[A-Za-z]*$ + Tests must extend org.apache.spark.SparkFunSuite instead. 
+ + + + + ^println$ + + + + + @VisibleForTesting + + + + + Runtime\.getRuntime\.addShutdownHook + + + + + mutable\.SynchronizedBuffer + + + + + Class\.forName + + + + + + JavaConversions + Instead of importing implicits in scala.collection.JavaConversions._, import + scala.collection.JavaConverters._ and use .asScala / .asJava methods + + + + + java,scala,3rdParty,spark + javax?\..* + scala\..* + (?!org\.apache\.spark\.).* + org\.apache\.spark\..* + + + + + + COMMA + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 800> + + + + + 30 + + + + + 10 + + + + + 50 + + + + + + + + + + + -1,0,1,2,3 + + + diff --git a/java/xgboost4j-demo/LICENSE b/jvm-packages/xgboost4j-demo/LICENSE similarity index 100% rename from java/xgboost4j-demo/LICENSE rename to jvm-packages/xgboost4j-demo/LICENSE diff --git a/java/xgboost4j-demo/README.md b/jvm-packages/xgboost4j-demo/README.md similarity index 100% rename from java/xgboost4j-demo/README.md rename to jvm-packages/xgboost4j-demo/README.md diff --git a/jvm-packages/xgboost4j-demo/pom.xml b/jvm-packages/xgboost4j-demo/pom.xml new file mode 100644 index 000000000..d8e679b78 --- /dev/null +++ b/jvm-packages/xgboost4j-demo/pom.xml @@ -0,0 +1,26 @@ + + + 4.0.0 + + org.dmlc + xgboostjvm + 0.1 + + xgboost4j-demo + 0.1 + jar + + + org.dmlc + xgboost4j + 0.1 + + + org.apache.commons + commons-lang3 + 3.4 + + + \ No newline at end of file diff --git a/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/demo/BasicWalkThrough.java b/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/demo/BasicWalkThrough.java new file mode 100644 index 000000000..af5dd8a86 --- /dev/null +++ b/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/demo/BasicWalkThrough.java @@ -0,0 +1,120 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ */ +package ml.dmlc.xgboost4j.demo; + +import java.io.File; +import java.io.IOException; +import java.util.Arrays; +import java.util.HashMap; + +import ml.dmlc.xgboost4j.Booster; +import ml.dmlc.xgboost4j.DMatrix; +import ml.dmlc.xgboost4j.XGBoost; +import ml.dmlc.xgboost4j.XGBoostError; +import ml.dmlc.xgboost4j.demo.util.DataLoader; + +/** + * a simple example of java wrapper for xgboost + * + * @author hzx + */ +public class BasicWalkThrough { + public static boolean checkPredicts(float[][] fPredicts, float[][] sPredicts) { + if (fPredicts.length != sPredicts.length) { + return false; + } + + for (int i = 0; i < fPredicts.length; i++) { + if (!Arrays.equals(fPredicts[i], sPredicts[i])) { + return false; + } + } + + return true; + } + + + public static void main(String[] args) throws IOException, XGBoostError { + // load file from text file, also binary buffer generated by xgboost4j + DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); + DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test"); + + HashMap params = new HashMap(); + params.put("eta", 1.0); + params.put("max_depth", 2); + params.put("silent", 1); + params.put("objective", "binary:logistic"); + + + HashMap watches = new HashMap(); + watches.put("train", trainMat); + watches.put("test", testMat); + + //set round + int round = 2; + + //train a boost model + Booster booster = XGBoost.train(params, trainMat, round, watches, null, null); + + //predict + float[][] predicts = booster.predict(testMat); + + //save model to modelPath + File file = new File("./model"); + if (!file.exists()) { + file.mkdirs(); + } + + String modelPath = "./model/xgb.model"; + booster.saveModel(modelPath); + + //dump model + booster.dumpModel("./model/dump.raw.txt", false); + + //dump model with feature map + booster.dumpModel("./model/dump.nice.txt", "../../demo/data/featmap.txt", false); + + //save dmatrix into binary buffer + testMat.saveBinary("./model/dtest.buffer"); + + //reload model and data + Booster booster2 = XGBoost.loadBoostModel(params, "./model/xgb.model"); + DMatrix testMat2 = new DMatrix("./model/dtest.buffer"); + float[][] predicts2 = booster2.predict(testMat2); + + + //check the two predicts + System.out.println(checkPredicts(predicts, predicts2)); + + System.out.println("start build dmatrix from csr sparse data ..."); + //build dmatrix from CSR Sparse Matrix + DataLoader.CSRSparseData spData = DataLoader.loadSVMFile("../../demo/data/agaricus.txt.train"); + + DMatrix trainMat2 = new DMatrix(spData.rowHeaders, spData.colIndex, spData.data, + DMatrix.SparseType.CSR); + trainMat2.setLabel(spData.labels); + + //specify watchList + HashMap watches2 = new HashMap(); + watches2.put("train", trainMat2); + watches2.put("test", testMat2); + Booster booster3 = XGBoost.train(params, trainMat2, round, watches2, null, null); + float[][] predicts3 = booster3.predict(testMat2); + + //check predicts + System.out.println(checkPredicts(predicts, predicts3)); + } +} diff --git a/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/demo/BoostFromPrediction.java b/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/demo/BoostFromPrediction.java new file mode 100644 index 000000000..335efc2d7 --- /dev/null +++ b/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/demo/BoostFromPrediction.java @@ -0,0 +1,62 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ +package ml.dmlc.xgboost4j.demo; + +import java.util.HashMap; + +import ml.dmlc.xgboost4j.Booster; +import ml.dmlc.xgboost4j.DMatrix; +import ml.dmlc.xgboost4j.XGBoost; +import ml.dmlc.xgboost4j.XGBoostError; + +/** + * example for start from a initial base prediction + * + * @author hzx + */ +public class BoostFromPrediction { + public static void main(String[] args) throws XGBoostError { + System.out.println("start running example to start from a initial prediction"); + + // load file from text file, also binary buffer generated by xgboost4j + DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); + DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test"); + + //specify parameters + HashMap params = new HashMap(); + params.put("eta", 1.0); + params.put("max_depth", 2); + params.put("silent", 1); + params.put("objective", "binary:logistic"); + + //specify watchList + HashMap watches = new HashMap(); + watches.put("train", trainMat); + watches.put("test", testMat); + + //train xgboost for 1 round + Booster booster = XGBoost.train(params, trainMat, 1, watches, null, null); + + float[][] trainPred = booster.predict(trainMat, true); + float[][] testPred = booster.predict(testMat, true); + + trainMat.setBaseMargin(trainPred); + testMat.setBaseMargin(testPred); + + System.out.println("result of running from initial prediction"); + Booster booster2 = XGBoost.train(params, trainMat, 1, watches, null, null); + } +} diff --git a/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/demo/CrossValidation.java b/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/demo/CrossValidation.java new file mode 100644 index 000000000..115b1dc5b --- /dev/null +++ b/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/demo/CrossValidation.java @@ -0,0 +1,54 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ */ +package ml.dmlc.xgboost4j.demo; + +import java.io.IOException; +import java.util.HashMap; + +import ml.dmlc.xgboost4j.DMatrix; +import ml.dmlc.xgboost4j.XGBoost; +import ml.dmlc.xgboost4j.XGBoostError; + +/** + * an example of cross validation + * + * @author hzx + */ +public class CrossValidation { + public static void main(String[] args) throws IOException, XGBoostError { + //load train mat + DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); + + //set params + HashMap params = new HashMap(); + + params.put("eta", 1.0); + params.put("max_depth", 3); + params.put("silent", 1); + params.put("nthread", 6); + params.put("objective", "binary:logistic"); + params.put("gamma", 1.0); + params.put("eval_metric", "error"); + + //do 5-fold cross validation + int round = 2; + int nfold = 5; + //set additional eval_metrics + String[] metrics = null; + + String[] evalHist = XGBoost.crossValiation(params, trainMat, round, nfold, metrics, null, null); + } +} diff --git a/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/demo/CustomObjective.java b/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/demo/CustomObjective.java new file mode 100644 index 000000000..be09fd701 --- /dev/null +++ b/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/demo/CustomObjective.java @@ -0,0 +1,167 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ */ +package ml.dmlc.xgboost4j.demo; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; + +import ml.dmlc.xgboost4j.*; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +/** + * an example user define objective and eval + * NOTE: when you do customized loss function, the default prediction value is margin + * this may make buildin evalution metric not function properly + * for example, we are doing logistic loss, the prediction is score before logistic transformation + * he buildin evaluation error assumes input is after logistic transformation + * Take this in mind when you use the customization, and maybe you need write customized evaluation + * function + * + * @author hzx + */ +public class CustomObjective { + /** + * loglikelihoode loss obj function + */ + public static class LogRegObj implements IObjective { + private static final Log logger = LogFactory.getLog(LogRegObj.class); + + /** + * simple sigmoid func + * + * @param input + * @return Note: this func is not concern about numerical stability, only used as example + */ + public float sigmoid(float input) { + float val = (float) (1 / (1 + Math.exp(-input))); + return val; + } + + public float[][] transform(float[][] predicts) { + int nrow = predicts.length; + float[][] transPredicts = new float[nrow][1]; + + for (int i = 0; i < nrow; i++) { + transPredicts[i][0] = sigmoid(predicts[i][0]); + } + + return transPredicts; + } + + @Override + public List getGradient(float[][] predicts, DMatrix dtrain) { + int nrow = predicts.length; + List gradients = new ArrayList(); + float[] labels; + try { + labels = dtrain.getLabel(); + } catch (XGBoostError ex) { + logger.error(ex); + return null; + } + float[] grad = new float[nrow]; + float[] hess = new float[nrow]; + + float[][] transPredicts = transform(predicts); + + for (int i = 0; i < nrow; i++) { + float predict = transPredicts[i][0]; + grad[i] = predict - labels[i]; + hess[i] = predict * (1 - predict); + } + + gradients.add(grad); + gradients.add(hess); + return gradients; + } + } + + /** + * user defined eval function. 
+ * NOTE: when you do customized loss function, the default prediction value is margin + * this may make buildin evalution metric not function properly + * for example, we are doing logistic loss, the prediction is score before logistic transformation + * the buildin evaluation error assumes input is after logistic transformation + * Take this in mind when you use the customization, and maybe you need write customized + * evaluation function + */ + public static class EvalError implements IEvaluation { + private static final Log logger = LogFactory.getLog(EvalError.class); + + String evalMetric = "custom_error"; + + public EvalError() { + } + + @Override + public String getMetric() { + return evalMetric; + } + + @Override + public float eval(float[][] predicts, DMatrix dmat) { + float error = 0f; + float[] labels; + try { + labels = dmat.getLabel(); + } catch (XGBoostError ex) { + logger.error(ex); + return -1f; + } + int nrow = predicts.length; + for (int i = 0; i < nrow; i++) { + if (labels[i] == 0f && predicts[i][0] > 0) { + error++; + } else if (labels[i] == 1f && predicts[i][0] <= 0) { + error++; + } + } + + return error / labels.length; + } + } + + public static void main(String[] args) throws XGBoostError { + //load train mat (svmlight format) + DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); + //load valid mat (svmlight format) + DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test"); + + HashMap params = new HashMap(); + params.put("eta", 1.0); + params.put("max_depth", 2); + params.put("silent", 1); + + + //set round + int round = 2; + + //specify watchList + HashMap watches = new HashMap(); + watches.put("train", trainMat); + watches.put("test", testMat); + + //user define obj and eval + IObjective obj = new LogRegObj(); + IEvaluation eval = new EvalError(); + + //train a booster + System.out.println("begin to train the booster model"); + Booster booster = XGBoost.train(params, trainMat, round, watches, obj, eval); + } +} diff --git a/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/demo/ExternalMemory.java b/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/demo/ExternalMemory.java new file mode 100644 index 000000000..095382953 --- /dev/null +++ b/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/demo/ExternalMemory.java @@ -0,0 +1,61 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ */ +package ml.dmlc.xgboost4j.demo; + +import java.util.HashMap; + +import ml.dmlc.xgboost4j.Booster; +import ml.dmlc.xgboost4j.DMatrix; +import ml.dmlc.xgboost4j.XGBoost; +import ml.dmlc.xgboost4j.XGBoostError; + +/** + * simple example for using external memory version + * + * @author hzx + */ +public class ExternalMemory { + public static void main(String[] args) throws XGBoostError { + //this is the only difference, add a # followed by a cache prefix name + //several cache file with the prefix will be generated + //currently only support convert from libsvm file + DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train#dtrain.cache"); + DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test#dtest.cache"); + + //specify parameters + HashMap params = new HashMap(); + params.put("eta", 1.0); + params.put("max_depth", 2); + params.put("silent", 1); + params.put("objective", "binary:logistic"); + + //performance notice: set nthread to be the number of your real cpu + //some cpu offer two threads per core, for example, a 4 core cpu with 8 threads, in such case + // set nthread=4 + //param.put("nthread", num_real_cpu); + + //specify watchList + HashMap watches = new HashMap(); + watches.put("train", trainMat); + watches.put("test", testMat); + + //set round + int round = 2; + + //train a boost model + Booster booster = XGBoost.train(params, trainMat, round, watches, null, null); + } +} diff --git a/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/demo/GeneralizedLinearModel.java b/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/demo/GeneralizedLinearModel.java new file mode 100644 index 000000000..8fae69032 --- /dev/null +++ b/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/demo/GeneralizedLinearModel.java @@ -0,0 +1,70 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ */ +package ml.dmlc.xgboost4j.demo; + +import ml.dmlc.xgboost4j.Booster; +import ml.dmlc.xgboost4j.DMatrix; +import ml.dmlc.xgboost4j.XGBoost; +import ml.dmlc.xgboost4j.XGBoostError; +import ml.dmlc.xgboost4j.demo.util.CustomEval; + +import java.util.HashMap; + +/** + * this is an example of fit generalized linear model in xgboost + * basically, we are using linear model, instead of tree for our boosters + * + * @author hzx + */ +public class GeneralizedLinearModel { + public static void main(String[] args) throws XGBoostError { + // load file from text file, also binary buffer generated by xgboost4j + DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); + DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test"); + + //specify parameters + //change booster to gblinear, so that we are fitting a linear model + // alpha is the L1 regularizer + //lambda is the L2 regularizer + //you can also set lambda_bias which is L2 regularizer on the bias term + HashMap params = new HashMap(); + params.put("alpha", 0.0001); + params.put("silent", 1); + params.put("objective", "binary:logistic"); + params.put("booster", "gblinear"); + + //normally, you do not need to set eta (step_size) + //XGBoost uses a parallel coordinate descent algorithm (shotgun), + //there could be affection on convergence with parallelization on certain cases + //setting eta to be smaller value, e.g 0.5 can make the optimization more stable + //param.put("eta", "0.5"); + + + //specify watchList + HashMap watches = new HashMap(); + watches.put("train", trainMat); + watches.put("test", testMat); + + //train a booster + int round = 4; + Booster booster = XGBoost.train(params, trainMat, round, watches, null, null); + + float[][] predicts = booster.predict(testMat); + + CustomEval eval = new CustomEval(); + System.out.println("error=" + eval.eval(predicts, testMat)); + } +} diff --git a/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/demo/PredictFirstNtree.java b/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/demo/PredictFirstNtree.java new file mode 100644 index 000000000..defa437d3 --- /dev/null +++ b/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/demo/PredictFirstNtree.java @@ -0,0 +1,66 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ */ +package ml.dmlc.xgboost4j.demo; + +import java.util.HashMap; + +import ml.dmlc.xgboost4j.Booster; +import ml.dmlc.xgboost4j.DMatrix; +import ml.dmlc.xgboost4j.XGBoost; +import ml.dmlc.xgboost4j.XGBoostError; +import ml.dmlc.xgboost4j.demo.util.CustomEval; + +/** + * predict first ntree + * + * @author hzx + */ +public class PredictFirstNtree { + public static void main(String[] args) throws XGBoostError { + // load file from text file, also binary buffer generated by xgboost4j + DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); + DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test"); + + //specify parameters + HashMap params = new HashMap(); + + params.put("eta", 1.0); + params.put("max_depth", 2); + params.put("silent", 1); + params.put("objective", "binary:logistic"); + + + //specify watchList + HashMap watches = new HashMap(); + watches.put("train", trainMat); + watches.put("test", testMat); + + + //train a booster + int round = 3; + Booster booster = XGBoost.train(params, trainMat, round, watches, null, null); + + //predict use 1 tree + float[][] predicts1 = booster.predict(testMat, false, 1); + //by default all trees are used to do predict + float[][] predicts2 = booster.predict(testMat); + + //use a simple evaluation class to check error result + CustomEval eval = new CustomEval(); + System.out.println("error of predicts1: " + eval.eval(predicts1, testMat)); + System.out.println("error of predicts2: " + eval.eval(predicts2, testMat)); + } +} diff --git a/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/demo/PredictLeafIndices.java b/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/demo/PredictLeafIndices.java new file mode 100644 index 000000000..d18987292 --- /dev/null +++ b/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/demo/PredictLeafIndices.java @@ -0,0 +1,66 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ */ +package ml.dmlc.xgboost4j.demo; + +import java.util.Arrays; +import java.util.HashMap; + +import ml.dmlc.xgboost4j.Booster; +import ml.dmlc.xgboost4j.DMatrix; +import ml.dmlc.xgboost4j.XGBoost; +import ml.dmlc.xgboost4j.XGBoostError; + +/** + * predict leaf indices + * + * @author hzx + */ +public class PredictLeafIndices { + public static void main(String[] args) throws XGBoostError { + // load file from text file, also binary buffer generated by xgboost4j + DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); + DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test"); + + //specify parameters + HashMap params = new HashMap(); + params.put("eta", 1.0); + params.put("max_depth", 2); + params.put("silent", 1); + params.put("objective", "binary:logistic"); + + //specify watchList + HashMap watches = new HashMap(); + watches.put("train", trainMat); + watches.put("test", testMat); + + + //train a booster + int round = 3; + Booster booster = XGBoost.train(params, trainMat, round, watches, null, null); + + //predict using first 2 tree + float[][] leafindex = booster.predict(testMat, 2, true); + for (float[] leafs : leafindex) { + System.out.println(Arrays.toString(leafs)); + } + + //predict all trees + leafindex = booster.predict(testMat, 0, true); + for (float[] leafs : leafindex) { + System.out.println(Arrays.toString(leafs)); + } + } +} diff --git a/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/demo/util/CustomEval.java b/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/demo/util/CustomEval.java new file mode 100644 index 000000000..31e841b03 --- /dev/null +++ b/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/demo/util/CustomEval.java @@ -0,0 +1,60 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ */ +package ml.dmlc.xgboost4j.demo.util; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import ml.dmlc.xgboost4j.DMatrix; +import ml.dmlc.xgboost4j.IEvaluation; +import ml.dmlc.xgboost4j.XGBoostError; + +/** + * a util evaluation class for examples + * + * @author hzx + */ +public class CustomEval implements IEvaluation { + private static final Log logger = LogFactory.getLog(CustomEval.class); + + String evalMetric = "custom_error"; + + @Override + public String getMetric() { + return evalMetric; + } + + @Override + public float eval(float[][] predicts, DMatrix dmat) { + float error = 0f; + float[] labels; + try { + labels = dmat.getLabel(); + } catch (XGBoostError ex) { + logger.error(ex); + return -1f; + } + int nrow = predicts.length; + for (int i = 0; i < nrow; i++) { + if (labels[i] == 0f && predicts[i][0] > 0.5) { + error++; + } else if (labels[i] == 1f && predicts[i][0] <= 0.5) { + error++; + } + } + + return error / labels.length; + } +} diff --git a/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/demo/util/DataLoader.java b/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/demo/util/DataLoader.java new file mode 100644 index 000000000..0dcaca8c2 --- /dev/null +++ b/jvm-packages/xgboost4j-demo/src/main/java/ml/dmlc/xgboost4j/demo/util/DataLoader.java @@ -0,0 +1,123 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ */ +package ml.dmlc.xgboost4j.demo.util; + +import org.apache.commons.lang3.ArrayUtils; + +import java.io.*; +import java.util.ArrayList; +import java.util.List; + +/** + * util class for loading data + * + * @author hzx + */ +public class DataLoader { + public static class DenseData { + public float[] labels; + public float[] data; + public int nrow; + public int ncol; + } + + public static class CSRSparseData { + public float[] labels; + public float[] data; + public long[] rowHeaders; + public int[] colIndex; + } + + public static DenseData loadCSVFile(String filePath) throws IOException { + DenseData denseData = new DenseData(); + + File f = new File(filePath); + FileInputStream in = new FileInputStream(f); + BufferedReader reader = new BufferedReader(new InputStreamReader(in, "UTF-8")); + + denseData.nrow = 0; + denseData.ncol = -1; + String line; + List tlabels = new ArrayList<>(); + List tdata = new ArrayList<>(); + + while ((line = reader.readLine()) != null) { + String[] items = line.trim().split(","); + if (items.length == 0) { + continue; + } + denseData.nrow++; + if (denseData.ncol == -1) { + denseData.ncol = items.length - 1; + } + + tlabels.add(Float.valueOf(items[items.length - 1])); + for (int i = 0; i < items.length - 1; i++) { + tdata.add(Float.valueOf(items[i])); + } + } + + reader.close(); + in.close(); + + denseData.labels = ArrayUtils.toPrimitive(tlabels.toArray(new Float[tlabels.size()])); + denseData.data = ArrayUtils.toPrimitive(tdata.toArray(new Float[tdata.size()])); + + return denseData; + } + + public static CSRSparseData loadSVMFile(String filePath) throws IOException { + CSRSparseData spData = new CSRSparseData(); + + List tlabels = new ArrayList<>(); + List tdata = new ArrayList<>(); + List theaders = new ArrayList<>(); + List tindex = new ArrayList<>(); + + File f = new File(filePath); + FileInputStream in = new FileInputStream(f); + BufferedReader reader = new BufferedReader(new InputStreamReader(in, "UTF-8")); + + String line; + long rowheader = 0; + theaders.add(rowheader); + while ((line = reader.readLine()) != null) { + String[] items = line.trim().split(" "); + if (items.length == 0) { + continue; + } + + rowheader += items.length - 1; + theaders.add(rowheader); + tlabels.add(Float.valueOf(items[0])); + + for (int i = 1; i < items.length; i++) { + String[] tup = items[i].split(":"); + assert tup.length == 2; + + tdata.add(Float.valueOf(tup[1])); + tindex.add(Integer.valueOf(tup[0])); + } + } + + spData.labels = ArrayUtils.toPrimitive(tlabels.toArray(new Float[tlabels.size()])); + spData.data = ArrayUtils.toPrimitive(tdata.toArray(new Float[tdata.size()])); + spData.colIndex = ArrayUtils.toPrimitive(tindex.toArray(new Integer[tindex.size()])); + spData.rowHeaders = ArrayUtils.toPrimitive(theaders.toArray(new Long[theaders.size()])); + + return spData; + } +} diff --git a/java/xgboost4j/LICENSE b/jvm-packages/xgboost4j/LICENSE similarity index 100% rename from java/xgboost4j/LICENSE rename to jvm-packages/xgboost4j/LICENSE diff --git a/jvm-packages/xgboost4j/pom.xml b/jvm-packages/xgboost4j/pom.xml new file mode 100644 index 000000000..271d918f9 --- /dev/null +++ b/jvm-packages/xgboost4j/pom.xml @@ -0,0 +1,35 @@ + + + 4.0.0 + + org.dmlc + xgboostjvm + 0.1 + + xgboost4j + 0.1 + jar + + + + org.apache.maven.plugins + maven-javadoc-plugin + 2.10.3 + + protected + true + + + + + + + junit + junit + 4.11 + test + + + diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/Booster.java 
b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/Booster.java new file mode 100644 index 000000000..e234fef60 --- /dev/null +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/Booster.java @@ -0,0 +1,153 @@ +package ml.dmlc.xgboost4j; + +import java.io.IOException; +import java.util.Map; + +public interface Booster { + + /** + * set parameter + * + * @param key param name + * @param value param value + */ + void setParam(String key, String value) throws XGBoostError; + + /** + * set parameters + * + * @param params parameters key-value map + */ + void setParams(Map params) throws XGBoostError; + + /** + * Update (one iteration) + * + * @param dtrain training data + * @param iter current iteration number + */ + void update(DMatrix dtrain, int iter) throws XGBoostError; + + /** + * update with customize obj func + * + * @param dtrain training data + * @param obj customized objective class + */ + void update(DMatrix dtrain, IObjective obj) throws XGBoostError; + + /** + * update with give grad and hess + * + * @param dtrain training data + * @param grad first order of gradient + * @param hess seconde order of gradient + */ + void boost(DMatrix dtrain, float[] grad, float[] hess) throws XGBoostError; + + /** + * evaluate with given dmatrixs. + * + * @param evalMatrixs dmatrixs for evaluation + * @param evalNames name for eval dmatrixs, used for check results + * @param iter current eval iteration + * @return eval information + */ + String evalSet(DMatrix[] evalMatrixs, String[] evalNames, int iter) throws XGBoostError; + + /** + * evaluate with given customized Evaluation class + * + * @param evalMatrixs evaluation matrix + * @param evalNames evaluation names + * @param eval custom evaluator + * @return eval information + */ + String evalSet(DMatrix[] evalMatrixs, String[] evalNames, IEvaluation eval) throws XGBoostError; + + /** + * Predict with data + * + * @param data dmatrix storing the input + * @return predict result + */ + float[][] predict(DMatrix data) throws XGBoostError; + + + /** + * Predict with data + * + * @param data dmatrix storing the input + * @param outPutMargin Whether to output the raw untransformed margin value. + * @return predict result + */ + float[][] predict(DMatrix data, boolean outPutMargin) throws XGBoostError; + + + /** + * Predict with data + * + * @param data dmatrix storing the input + * @param outPutMargin Whether to output the raw untransformed margin value. + * @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees). + * @return predict result + */ + float[][] predict(DMatrix data, boolean outPutMargin, int treeLimit) throws XGBoostError; + + + /** + * Predict with data + * @param data dmatrix storing the input + * @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees). + * @param predLeaf When this option is on, the output will be a matrix of (nsample, ntrees), + * nsample = data.numRow with each record indicating the predicted leaf index of + * each sample in each tree. Note that the leaf index of a tree is unique per + * tree, so you may find leaf 1 in both tree 1 and tree 0. + * @return predict result + * @throws XGBoostError native error + */ + float[][] predict(DMatrix data, int treeLimit, boolean predLeaf) throws XGBoostError; + + /** + * save model to modelPath + * + * @param modelPath model path + */ + void saveModel(String modelPath) throws XGBoostError; + + /** + * Dump model into a text file. 
+ * + * @param modelPath file to save dumped model info + * @param withStats bool Controls whether the split statistics are output. + */ + void dumpModel(String modelPath, boolean withStats) throws IOException, XGBoostError; + + /** + * Dump model into a text file. + * + * @param modelPath file to save dumped model info + * @param featureMap featureMap file + * @param withStats bool + * Controls whether the split statistics are output. + */ + void dumpModel(String modelPath, String featureMap, boolean withStats) + throws IOException, XGBoostError; + + /** + * get importance of each feature + * + * @return featureMap key: feature index, value: feature importance score + */ + Map getFeatureScore() throws XGBoostError ; + + /** + * get importance of each feature + * + * @param featureMap file to save dumped model info + * @return featureMap key: feature index, value: feature importance score + */ + Map getFeatureScore(String featureMap) throws XGBoostError; + + void dispose(); +} diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/DMatrix.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/DMatrix.java new file mode 100644 index 000000000..4b498caf1 --- /dev/null +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/DMatrix.java @@ -0,0 +1,256 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
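The Booster interface added above exposes four predict overloads (plain scores, raw margins, a tree limit, and leaf indices). The snippet below is only an assumed usage sketch, not part of this patch; it presumes a model trained through XGBoost.train as in the demo sources, and writes the generic type parameters explicitly even though they are stripped in this rendering of the patch.

import java.util.HashMap;

import ml.dmlc.xgboost4j.Booster;
import ml.dmlc.xgboost4j.DMatrix;
import ml.dmlc.xgboost4j.XGBoost;
import ml.dmlc.xgboost4j.XGBoostError;

public class PredictOverloadsSketch {
  public static void main(String[] args) throws XGBoostError {
    // assumed setup, mirroring the demos shipped with this patch
    DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
    DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");

    HashMap<String, Object> params = new HashMap<String, Object>();
    params.put("eta", 1.0);
    params.put("max_depth", 2);
    params.put("objective", "binary:logistic");

    HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
    watches.put("train", trainMat);

    Booster booster = XGBoost.train(params, trainMat, 2, watches, null, null);

    float[][] scores  = booster.predict(testMat);           // transformed predictions
    float[][] margins = booster.predict(testMat, true);     // raw, untransformed margins
    float[][] oneTree = booster.predict(testMat, false, 1); // use only the first tree
    float[][] leaves  = booster.predict(testMat, 0, true);  // leaf indices over all trees
  }
}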
+ */ +package ml.dmlc.xgboost4j; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import java.io.IOException; + +/** + * DMatrix for xgboost, similar to the python wrapper xgboost.py + * + * @author hzx + */ +public class DMatrix { + private static final Log logger = LogFactory.getLog(DMatrix.class); + private long handle = 0; + + //load native library + static { + try { + NativeLibLoader.initXgBoost(); + } catch (IOException ex) { + logger.error("load native library failed."); + logger.error(ex); + } + } + + /** + * sparse matrix type (CSR or CSC) + */ + public static enum SparseType { + CSR, + CSC; + } + + public DMatrix(String dataPath) throws XGBoostError { + if (dataPath == null) { + throw new NullPointerException("dataPath: null"); + } + long[] out = new long[1]; + JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixCreateFromFile(dataPath, 1, out)); + handle = out[0]; + } + + public DMatrix(long[] headers, int[] indices, float[] data, SparseType st) throws XGBoostError { + long[] out = new long[1]; + if (st == SparseType.CSR) { + JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixCreateFromCSR(headers, indices, data, out)); + } else if (st == SparseType.CSC) { + JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixCreateFromCSC(headers, indices, data, out)); + } else { + throw new UnknownError("unknow sparsetype"); + } + handle = out[0]; + } + + /** + * create DMatrix from dense matrix + * + * @param data data values + * @param nrow number of rows + * @param ncol number of columns + * @throws XGBoostError native error + */ + public DMatrix(float[] data, int nrow, int ncol) throws XGBoostError { + long[] out = new long[1]; + JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixCreateFromMat(data, nrow, ncol, 0.0f, out)); + handle = out[0]; + } + + /** + * used for DMatrix slice + */ + protected DMatrix(long handle) { + this.handle = handle; + } + + + /** + * set label of dmatrix + * + * @param labels labels + * @throws XGBoostError native error + */ + public void setLabel(float[] labels) throws XGBoostError { + JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixSetFloatInfo(handle, "label", labels)); + } + + /** + * set weight of each instance + * + * @param weights weights + * @throws XGBoostError native error + */ + public void setWeight(float[] weights) throws XGBoostError { + JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixSetFloatInfo(handle, "weight", weights)); + } + + /** + * if specified, xgboost will start from this init margin + * can be used to specify initial prediction to boost from + * + * @param baseMargin base margin + * @throws XGBoostError native error + */ + public void setBaseMargin(float[] baseMargin) throws XGBoostError { + JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixSetFloatInfo(handle, "base_margin", baseMargin)); + } + + /** + * if specified, xgboost will start from this init margin + * can be used to specify initial prediction to boost from + * + * @param baseMargin base margin + * @throws XGBoostError native error + */ + public void setBaseMargin(float[][] baseMargin) throws XGBoostError { + float[] flattenMargin = flatten(baseMargin); + setBaseMargin(flattenMargin); + } + + /** + * Set group sizes of DMatrix (used for ranking) + * + * @param group group size as array + * @throws XGBoostError native error + */ + public void setGroup(int[] group) throws XGBoostError { + JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixSetGroup(handle, group)); + } + + private float[] getFloatInfo(String field) throws XGBoostError { + float[][] infos = new float[1][]; + 
JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixGetFloatInfo(handle, field, infos)); + return infos[0]; + } + + private int[] getIntInfo(String field) throws XGBoostError { + int[][] infos = new int[1][]; + JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixGetUIntInfo(handle, field, infos)); + return infos[0]; + } + + /** + * get label values + * + * @return label + * @throws XGBoostError native error + */ + public float[] getLabel() throws XGBoostError { + return getFloatInfo("label"); + } + + /** + * get weight of the DMatrix + * + * @return weights + * @throws XGBoostError native error + */ + public float[] getWeight() throws XGBoostError { + return getFloatInfo("weight"); + } + + /** + * get base margin of the DMatrix + * + * @return base margin + * @throws XGBoostError native error + */ + public float[] getBaseMargin() throws XGBoostError { + return getFloatInfo("base_margin"); + } + + /** + * Slice the DMatrix and return a new DMatrix that only contains `rowIndex`. + * + * @param rowIndex row index + * @return sliced new DMatrix + * @throws XGBoostError native error + */ + public DMatrix slice(int[] rowIndex) throws XGBoostError { + long[] out = new long[1]; + JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixSliceDMatrix(handle, rowIndex, out)); + long sHandle = out[0]; + DMatrix sMatrix = new DMatrix(sHandle); + return sMatrix; + } + + /** + * get the row number of DMatrix + * + * @return number of rows + * @throws XGBoostError native error + */ + public long rowNum() throws XGBoostError { + long[] rowNum = new long[1]; + JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixNumRow(handle, rowNum)); + return rowNum[0]; + } + + /** + * save DMatrix to filePath + */ + public void saveBinary(String filePath) { + XgboostJNI.XGDMatrixSaveBinary(handle, filePath, 1); + } + + /** + * Get the handle + */ + public long getHandle() { + return handle; + } + + /** + * flatten a mat to array + */ + private static float[] flatten(float[][] mat) { + int size = 0; + for (float[] array : mat) size += array.length; + float[] result = new float[size]; + int pos = 0; + for (float[] ar : mat) { + System.arraycopy(ar, 0, result, pos, ar.length); + pos += ar.length; + } + + return result; + } + + @Override + protected void finalize() { + dispose(); + } + + public synchronized void dispose() { + if (handle != 0) { + XgboostJNI.XGDMatrixFree(handle); + handle = 0; + } + } +} diff --git a/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/IEvaluation.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/IEvaluation.java similarity index 59% rename from java/xgboost4j/src/main/java/org/dmlc/xgboost4j/IEvaluation.java rename to jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/IEvaluation.java index 3793bff41..079cd057e 100644 --- a/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/IEvaluation.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/IEvaluation.java @@ -1,10 +1,10 @@ /* - Copyright (c) 2014 by Contributors + Copyright (c) 2014 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - + http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software @@ -13,29 +13,27 @@ See the License for the specific language governing permissions and limitations under the License. 
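To make the DMatrix constructors above concrete, here is a small sketch (values invented, not part of the patch) that builds the same 2x3 matrix once from a dense row-major array and once from CSR arrays:

import ml.dmlc.xgboost4j.DMatrix;
import ml.dmlc.xgboost4j.XGBoostError;

public class DMatrixSketch {
  public static void main(String[] args) throws XGBoostError {
    // dense 2x3 matrix, row-major: [[1, 0, 2], [0, 3, 4]]
    float[] dense = new float[] {1f, 0f, 2f, 0f, 3f, 4f};
    DMatrix denseMat = new DMatrix(dense, 2, 3);
    denseMat.setLabel(new float[] {1f, 0f});
    denseMat.setWeight(new float[] {1f, 0.5f});

    // the same matrix in CSR form: row pointers, column indices, non-zero values
    long[] rowHeaders = new long[] {0, 2, 4};
    int[] colIndex = new int[] {0, 2, 1, 2};
    float[] data = new float[] {1f, 2f, 3f, 4f};
    DMatrix csrMat = new DMatrix(rowHeaders, colIndex, data, DMatrix.SparseType.CSR);
    System.out.println(csrMat.rowNum()); // prints 2

    // native handles can be freed explicitly (finalize() also calls dispose())
    denseMat.dispose();
    csrMat.dispose();
  }
}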
*/ -package org.dmlc.xgboost4j; +package ml.dmlc.xgboost4j; /** * interface for customized evaluation - * + * * @author hzx */ public interface IEvaluation { - /** - * get evaluate metric - * - * @return evalMetric - */ - public abstract String getMetric(); + /** + * get evaluate metric + * + * @return evalMetric + */ + String getMetric(); - /** - * evaluate with predicts and data - * - * @param predicts - * predictions as array - * @param dmat - * data matrix to evaluate - * @return result of the metric - */ - public abstract float eval(float[][] predicts, DMatrix dmat); + /** + * evaluate with predicts and data + * + * @param predicts predictions as array + * @param dmat data matrix to evaluate + * @return result of the metric + */ + float eval(float[][] predicts, DMatrix dmat); } diff --git a/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/IObjective.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/IObjective.java similarity index 60% rename from java/xgboost4j/src/main/java/org/dmlc/xgboost4j/IObjective.java rename to jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/IObjective.java index 640f46e6d..97ef9aed4 100644 --- a/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/IObjective.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/IObjective.java @@ -1,10 +1,10 @@ /* - Copyright (c) 2014 by Contributors + Copyright (c) 2014 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - + http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software @@ -13,20 +13,22 @@ See the License for the specific language governing permissions and limitations under the License. */ -package org.dmlc.xgboost4j; +package ml.dmlc.xgboost4j; import java.util.List; /** * interface for customize Object function + * * @author hzx */ public interface IObjective { - /** - * user define objective function, return gradient and second order gradient - * @param predicts untransformed margin predicts - * @param dtrain training data - * @return List with two float array, correspond to first order grad and second order grad - */ - public abstract List getGradient(float[][] predicts, DMatrix dtrain); + /** + * user define objective function, return gradient and second order gradient + * + * @param predicts untransformed margin predicts + * @param dtrain training data + * @return List with two float array, correspond to first order grad and second order grad + */ + List getGradient(float[][] predicts, DMatrix dtrain); } diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/JNIErrorHandle.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/JNIErrorHandle.java new file mode 100644 index 000000000..06474dbb4 --- /dev/null +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/JNIErrorHandle.java @@ -0,0 +1,51 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
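As a sketch of the IObjective contract refactored above (an assumed example, deliberately simpler than the logistic objective in the demo package): the gradient and hessian are returned as a two-element list, one float array per row of the training data.

import java.util.ArrayList;
import java.util.List;

import ml.dmlc.xgboost4j.DMatrix;
import ml.dmlc.xgboost4j.IObjective;
import ml.dmlc.xgboost4j.XGBoostError;

public class SquaredErrorObjective implements IObjective {
  @Override
  public List<float[]> getGradient(float[][] predicts, DMatrix dtrain) {
    float[] labels;
    try {
      labels = dtrain.getLabel();
    } catch (XGBoostError ex) {
      return null;
    }
    // squared error: gradient = prediction - label, hessian = 1
    float[] grad = new float[predicts.length];
    float[] hess = new float[predicts.length];
    for (int i = 0; i < predicts.length; i++) {
      grad[i] = predicts[i][0] - labels[i];
      hess[i] = 1f;
    }
    List<float[]> gradients = new ArrayList<float[]>();
    gradients.add(grad);
    gradients.add(hess);
    return gradients;
  }
}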
+ */ +package ml.dmlc.xgboost4j; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import java.io.IOException; + +/** + * Error handle for Xgboost. + */ +class JNIErrorHandle { + + private static final Log logger = LogFactory.getLog(DMatrix.class); + + //load native library + static { + try { + NativeLibLoader.initXgBoost(); + } catch (IOException ex) { + logger.error("load native library failed."); + logger.error(ex); + } + } + + /** + * Check the return value of C API. + * + * @param ret return valud of xgboostJNI C API call + * @throws XGBoostError native error + */ + static void checkCall(int ret) throws XGBoostError { + if (ret != 0) { + throw new XGBoostError(XgboostJNI.XGBGetLastError()); + } + } +} diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/JavaBoosterImpl.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/JavaBoosterImpl.java new file mode 100644 index 000000000..321b7fead --- /dev/null +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/JavaBoosterImpl.java @@ -0,0 +1,470 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ +package ml.dmlc.xgboost4j; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import java.io.*; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + + +/** + * Booster for xgboost, similar to the python wrapper xgboost.py + * but custom obj function and eval function not supported at present. 
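JNIErrorHandle.checkCall is the single choke point that turns a non-zero native status code into an XGBoostError carrying the XGBGetLastError message. A hypothetical package-internal helper sketching that pattern (the class and method names below are illustrative only):

package ml.dmlc.xgboost4j;

class NativeCallExample {
  // the raw int status from the JNI layer becomes a checked XGBoostError on failure
  static void setGroupChecked(long handle, int[] group) throws XGBoostError {
    JNIErrorHandle.checkCall(XgboostJNI.XGDMatrixSetGroup(handle, group));
  }
}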
+ * + * @author hzx + */ +class JavaBoosterImpl implements Booster { + private static final Log logger = LogFactory.getLog(JavaBoosterImpl.class); + + long handle = 0; + + //load native library + static { + try { + NativeLibLoader.initXgBoost(); + } catch (IOException ex) { + logger.error("load native library failed."); + logger.error(ex); + } + } + + /** + * init Booster from dMatrixs + * + * @param params parameters + * @param dMatrixs DMatrix array + * @throws XGBoostError native error + */ + JavaBoosterImpl(Map params, DMatrix[] dMatrixs) throws XGBoostError { + init(dMatrixs); + setParam("seed", "0"); + setParams(params); + } + + + /** + * load model from modelPath + * + * @param params parameters + * @param modelPath booster modelPath (model generated by booster.saveModel) + * @throws XGBoostError native error + */ + JavaBoosterImpl(Map params, String modelPath) throws XGBoostError { + init(null); + if (modelPath == null) { + throw new NullPointerException("modelPath : null"); + } + loadModel(modelPath); + setParam("seed", "0"); + setParams(params); + } + + + private void init(DMatrix[] dMatrixs) throws XGBoostError { + long[] handles = null; + if (dMatrixs != null) { + handles = dmatrixsToHandles(dMatrixs); + } + long[] out = new long[1]; + JNIErrorHandle.checkCall(XgboostJNI.XGBoosterCreate(handles, out)); + + handle = out[0]; + } + + /** + * set parameter + * + * @param key param name + * @param value param value + * @throws XGBoostError native error + */ + public final void setParam(String key, String value) throws XGBoostError { + JNIErrorHandle.checkCall(XgboostJNI.XGBoosterSetParam(handle, key, value)); + } + + /** + * set parameters + * + * @param params parameters key-value map + * @throws XGBoostError native error + */ + public void setParams(Map params) throws XGBoostError { + if (params != null) { + for (Map.Entry entry : params.entrySet()) { + setParam(entry.getKey(), entry.getValue().toString()); + } + } + } + + + /** + * Update (one iteration) + * + * @param dtrain training data + * @param iter current iteration number + * @throws XGBoostError native error + */ + public void update(DMatrix dtrain, int iter) throws XGBoostError { + JNIErrorHandle.checkCall(XgboostJNI.XGBoosterUpdateOneIter(handle, iter, dtrain.getHandle())); + } + + /** + * update with customize obj func + * + * @param dtrain training data + * @param obj customized objective class + * @throws XGBoostError native error + */ + public void update(DMatrix dtrain, IObjective obj) throws XGBoostError { + float[][] predicts = predict(dtrain, true); + List gradients = obj.getGradient(predicts, dtrain); + boost(dtrain, gradients.get(0), gradients.get(1)); + } + + /** + * update with give grad and hess + * + * @param dtrain training data + * @param grad first order of gradient + * @param hess seconde order of gradient + * @throws XGBoostError native error + */ + public void boost(DMatrix dtrain, float[] grad, float[] hess) throws XGBoostError { + if (grad.length != hess.length) { + throw new AssertionError(String.format("grad/hess length mismatch %s / %s", grad.length, + hess.length)); + } + JNIErrorHandle.checkCall(XgboostJNI.XGBoosterBoostOneIter(handle, dtrain.getHandle(), grad, + hess)); + } + + /** + * evaluate with given dmatrixs. 
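Because JavaBoosterImpl is package-private, application code obtains a Booster through the public XGBoost factory methods introduced later in this patch. A sketch of driving update manually, round by round; the class name, file path, and parameter values are placeholders.

import java.util.HashMap;
import java.util.Map;
import ml.dmlc.xgboost4j.Booster;
import ml.dmlc.xgboost4j.DMatrix;
import ml.dmlc.xgboost4j.XGBoost;
import ml.dmlc.xgboost4j.XGBoostError;

public class ManualUpdateExample {
  public static void main(String[] args) throws XGBoostError {
    DMatrix dtrain = new DMatrix("train.svm.txt");   // placeholder path

    Map<String, Object> params = new HashMap<String, Object>();
    params.put("eta", 0.1);
    params.put("max_depth", 3);
    params.put("objective", "binary:logistic");

    Booster booster = XGBoost.initBoostingModel(params, new DMatrix[]{dtrain});

    // one boosting update per round, which is what XGBoost.train does internally
    for (int iter = 0; iter < 10; iter++) {
      booster.update(dtrain, iter);
    }

    booster.dispose();
    dtrain.dispose();
  }
}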
+ * + * @param evalMatrixs dmatrixs for evaluation + * @param evalNames name for eval dmatrixs, used for check results + * @param iter current eval iteration + * @return eval information + * @throws XGBoostError native error + */ + public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, int iter) throws XGBoostError { + long[] handles = dmatrixsToHandles(evalMatrixs); + String[] evalInfo = new String[1]; + JNIErrorHandle.checkCall(XgboostJNI.XGBoosterEvalOneIter(handle, iter, handles, evalNames, + evalInfo)); + return evalInfo[0]; + } + + /** + * evaluate with given customized Evaluation class + * + * @param evalMatrixs evaluation matrix + * @param evalNames evaluation names + * @param eval custom evaluator + * @return eval information + * @throws XGBoostError native error + */ + public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, IEvaluation eval) + throws XGBoostError { + String evalInfo = ""; + for (int i = 0; i < evalNames.length; i++) { + String evalName = evalNames[i]; + DMatrix evalMat = evalMatrixs[i]; + float evalResult = eval.eval(predict(evalMat), evalMat); + String evalMetric = eval.getMetric(); + evalInfo += String.format("\t%s-%s:%f", evalName, evalMetric, evalResult); + } + return evalInfo; + } + + /** + * base function for Predict + * + * @param data data + * @param outPutMargin output margin + * @param treeLimit limit number of trees + * @param predLeaf prediction minimum to keep leafs + * @return predict results + */ + private synchronized float[][] pred(DMatrix data, boolean outPutMargin, int treeLimit, + boolean predLeaf) throws XGBoostError { + int optionMask = 0; + if (outPutMargin) { + optionMask = 1; + } + if (predLeaf) { + optionMask = 2; + } + float[][] rawPredicts = new float[1][]; + JNIErrorHandle.checkCall(XgboostJNI.XGBoosterPredict(handle, data.getHandle(), optionMask, + treeLimit, rawPredicts)); + int row = (int) data.rowNum(); + int col = rawPredicts[0].length / row; + float[][] predicts = new float[row][col]; + int r, c; + for (int i = 0; i < rawPredicts[0].length; i++) { + r = i / col; + c = i % col; + predicts[r][c] = rawPredicts[0][i]; + } + return predicts; + } + + /** + * Predict with data + * + * @param data dmatrix storing the input + * @return predict result + * @throws XGBoostError native error + */ + public float[][] predict(DMatrix data) throws XGBoostError { + return pred(data, false, 0, false); + } + + /** + * Predict with data + * + * @param data dmatrix storing the input + * @param outPutMargin Whether to output the raw untransformed margin value. + * @return predict result + * @throws XGBoostError native error + */ + public float[][] predict(DMatrix data, boolean outPutMargin) throws XGBoostError { + return pred(data, outPutMargin, 0, false); + } + + /** + * Predict with data + * + * @param data dmatrix storing the input + * @param outPutMargin Whether to output the raw untransformed margin value. + * @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees). + * @return predict result + * @throws XGBoostError native error + */ + public float[][] predict(DMatrix data, boolean outPutMargin, int treeLimit) throws XGBoostError { + return pred(data, outPutMargin, treeLimit, false); + } + + /** + * Predict with data + * + * @param data dmatrix storing the input + * @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees). 
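A sketch of the predict overloads above, assuming a trained Booster and a test DMatrix produced elsewhere; the class and method names are placeholders.

import ml.dmlc.xgboost4j.Booster;
import ml.dmlc.xgboost4j.DMatrix;
import ml.dmlc.xgboost4j.XGBoostError;

public class PredictExamples {
  static void showPredictions(Booster booster, DMatrix dtest) throws XGBoostError {
    float[][] prob = booster.predict(dtest);                    // transformed predictions
    float[][] margin = booster.predict(dtest, true);            // raw untransformed margins
    float[][] firstTwoTrees = booster.predict(dtest, false, 2); // use only the first 2 trees

    // one row per input row; a single column for single-output objectives
    System.out.println("p(row 0) = " + prob[0][0]
        + ", margin(row 0) = " + margin[0][0]
        + ", 2-tree estimate = " + firstTwoTrees[0][0]);
  }
}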
+ * @param predLeaf When this option is on, the output will be a matrix of (nsample, ntrees), + * nsample = data.numRow with each record indicating the predicted leaf index + * of each sample in each tree. + * Note that the leaf index of a tree is unique per tree, so you may find leaf 1 + * in both tree 1 and tree 0. + * @return predict result + * @throws XGBoostError native error + */ + public float[][] predict(DMatrix data, int treeLimit, boolean predLeaf) throws XGBoostError { + return pred(data, false, treeLimit, predLeaf); + } + + /** + * save model to modelPath + * + * @param modelPath model path + */ + public void saveModel(String modelPath) throws XGBoostError{ + JNIErrorHandle.checkCall(XgboostJNI.XGBoosterSaveModel(handle, modelPath)); + } + + private void loadModel(String modelPath) { + XgboostJNI.XGBoosterLoadModel(handle, modelPath); + } + + /** + * get the dump of the model as a string array + * + * @param withStats Controls whether the split statistics are output. + * @return dumped model information + * @throws XGBoostError native error + */ + private String[] getDumpInfo(boolean withStats) throws XGBoostError { + int statsFlag = 0; + if (withStats) { + statsFlag = 1; + } + String[][] modelInfos = new String[1][]; + JNIErrorHandle.checkCall(XgboostJNI.XGBoosterDumpModel(handle, "", statsFlag, modelInfos)); + return modelInfos[0]; + } + + /** + * get the dump of the model as a string array + * + * @param featureMap featureMap file + * @param withStats Controls whether the split statistics are output. + * @return dumped model information + * @throws XGBoostError native error + */ + private String[] getDumpInfo(String featureMap, boolean withStats) throws XGBoostError { + int statsFlag = 0; + if (withStats) { + statsFlag = 1; + } + String[][] modelInfos = new String[1][]; + JNIErrorHandle.checkCall(XgboostJNI.XGBoosterDumpModel(handle, featureMap, statsFlag, + modelInfos)); + return modelInfos[0]; + } + + /** + * Dump model into a text file. + * + * @param modelPath file to save dumped model info + * @param withStats bool + * Controls whether the split statistics are output. + * @throws FileNotFoundException file not found + * @throws UnsupportedEncodingException unsupported feature + * @throws IOException error with model writing + * @throws XGBoostError native error + */ + public void dumpModel(String modelPath, boolean withStats) throws IOException, XGBoostError { + File tf = new File(modelPath); + FileOutputStream out = new FileOutputStream(tf); + BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(out, "UTF-8")); + String[] modelInfos = getDumpInfo(withStats); + + for (int i = 0; i < modelInfos.length; i++) { + writer.write("booster [" + i + "]:\n"); + writer.write(modelInfos[i]); + } + + writer.close(); + out.close(); + } + + + /** + * Dump model into a text file. + * + * @param modelPath file to save dumped model info + * @param featureMap featureMap file + * @param withStats bool + * Controls whether the split statistics are output. 
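A sketch combining saveModel, dumpModel, and leaf-index prediction; the output paths and class name are placeholders.

import java.io.IOException;
import ml.dmlc.xgboost4j.Booster;
import ml.dmlc.xgboost4j.DMatrix;
import ml.dmlc.xgboost4j.XGBoostError;

public class ModelPersistenceExample {
  // assumes `booster` is a trained model and `dtrain` is data with the same feature layout
  static void persistAndInspect(Booster booster, DMatrix dtrain)
      throws IOException, XGBoostError {
    booster.saveModel("xgb.model");           // binary model, reloadable via XGBoost.loadBoostModel
    booster.dumpModel("dump.raw.txt", true);  // text dump of every tree, including split statistics

    // leaf-index prediction: one column per tree; values are leaf ids within each tree
    float[][] leaves = booster.predict(dtrain, 0, true);
    System.out.println("row 0 visited " + leaves[0].length + " trees");
  }
}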
+ * @throws FileNotFoundException exception + * @throws UnsupportedEncodingException exception + * @throws IOException exception + * @throws XGBoostError native error + */ + public void dumpModel(String modelPath, String featureMap, boolean withStats) throws + IOException, XGBoostError { + File tf = new File(modelPath); + FileOutputStream out = new FileOutputStream(tf); + BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(out, "UTF-8")); + String[] modelInfos = getDumpInfo(featureMap, withStats); + + for (int i = 0; i < modelInfos.length; i++) { + writer.write("booster [" + i + "]:\n"); + writer.write(modelInfos[i]); + } + + writer.close(); + out.close(); + } + + + /** + * get importance of each feature + * + * @return featureMap key: feature index, value: feature importance score + * @throws XGBoostError native error + */ + public Map getFeatureScore() throws XGBoostError { + String[] modelInfos = getDumpInfo(false); + Map featureScore = new HashMap(); + for (String tree : modelInfos) { + for (String node : tree.split("\n")) { + String[] array = node.split("\\["); + if (array.length == 1) { + continue; + } + String fid = array[1].split("\\]")[0]; + fid = fid.split("<")[0]; + if (featureScore.containsKey(fid)) { + featureScore.put(fid, 1 + featureScore.get(fid)); + } else { + featureScore.put(fid, 1); + } + } + } + return featureScore; + } + + + /** + * get importance of each feature + * + * @param featureMap file to save dumped model info + * @return featureMap key: feature index, value: feature importance score + * @throws XGBoostError native error + */ + public Map getFeatureScore(String featureMap) throws XGBoostError { + String[] modelInfos = getDumpInfo(featureMap, false); + Map featureScore = new HashMap(); + for (String tree : modelInfos) { + for (String node : tree.split("\n")) { + String[] array = node.split("\\["); + if (array.length == 1) { + continue; + } + String fid = array[1].split("\\]")[0]; + fid = fid.split("<")[0]; + if (featureScore.containsKey(fid)) { + featureScore.put(fid, 1 + featureScore.get(fid)); + } else { + featureScore.put(fid, 1); + } + } + } + return featureScore; + } + + /** + * transfer DMatrix array to handle array (used for native functions) + * + * @param dmatrixs + * @return handle array for input dmatrixs + */ + private static long[] dmatrixsToHandles(DMatrix[] dmatrixs) { + long[] handles = new long[dmatrixs.length]; + for (int i = 0; i < dmatrixs.length; i++) { + handles[i] = dmatrixs[i].getHandle(); + } + return handles; + } + + @Override + protected void finalize() throws Throwable { + super.finalize(); + dispose(); + } + + public synchronized void dispose() { + if (handle != 0L) { + XgboostJNI.XGBoosterFree(handle); + handle = 0; + } + } +} diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/NativeLibLoader.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/NativeLibLoader.java new file mode 100644 index 000000000..85e60b3ef --- /dev/null +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/NativeLibLoader.java @@ -0,0 +1,170 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
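getFeatureScore parses the model dump and counts how often each feature id appears as a split. A short sketch that prints those counts; the class name is a placeholder.

import java.util.Map;
import ml.dmlc.xgboost4j.Booster;
import ml.dmlc.xgboost4j.XGBoostError;

public class FeatureScoreExample {
  static void printFeatureScores(Booster booster) throws XGBoostError {
    // key: feature id as it appears in the dump, value: number of splits using it
    Map<String, Integer> scores = booster.getFeatureScore();
    for (Map.Entry<String, Integer> entry : scores.entrySet()) {
      System.out.println(entry.getKey() + " -> " + entry.getValue());
    }
  }
}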
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ +package ml.dmlc.xgboost4j; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import java.io.*; +import java.lang.reflect.Field; + +/** + * class to load native library + * + * @author hzx + */ +class NativeLibLoader { + private static final Log logger = LogFactory.getLog(NativeLibLoader.class); + + private static boolean initialized = false; + private static final String nativePath = "../lib/"; + private static final String nativeResourcePath = "/lib/"; + private static final String[] libNames = new String[]{"xgboost4j"}; + + public static synchronized void initXgBoost() throws IOException { + if (!initialized) { + for (String libName : libNames) { + smartLoad(libName); + } + initialized = true; + } + } + + /** + * Loads library from current JAR archive + *
+ * The file from JAR is copied into system temporary directory and then loaded. + * The temporary file is deleted after exiting. + * Method uses String as filename because the pathname is "abstract", not system-dependent. + *
+ * The restrictions of {@link File#createTempFile(java.lang.String, java.lang.String)} apply to + * {@code path}. + * + * @param path The filename inside JAR as absolute path (beginning with '/'), + * e.g. /package/File.ext + * @throws IOException If temporary file creation or read/write operation fails + * @throws IllegalArgumentException If source file (param path) does not exist + * @throws IllegalArgumentException If the path is not absolute or if the filename is shorter than + * three characters + */ + private static void loadLibraryFromJar(String path) throws IOException, IllegalArgumentException{ + + if (!path.startsWith("/")) { + throw new IllegalArgumentException("The path has to be absolute (start with '/')."); + } + + // Obtain filename from path + String[] parts = path.split("/"); + String filename = (parts.length > 1) ? parts[parts.length - 1] : null; + + // Split filename to prexif and suffix (extension) + String prefix = ""; + String suffix = null; + if (filename != null) { + parts = filename.split("\\.", 2); + prefix = parts[0]; + suffix = (parts.length > 1) ? "." + parts[parts.length - 1] : null; // Thanks, davs! :-) + } + + // Check if the filename is okay + if (filename == null || prefix.length() < 3) { + throw new IllegalArgumentException("The filename has to be at least 3 characters long."); + } + + // Prepare temporary file + File temp = File.createTempFile(prefix, suffix); + temp.deleteOnExit(); + + if (!temp.exists()) { + throw new FileNotFoundException("File " + temp.getAbsolutePath() + " does not exist."); + } + + // Prepare buffer for data copying + byte[] buffer = new byte[1024]; + int readBytes; + + // Open and check input stream + InputStream is = NativeLibLoader.class.getResourceAsStream(path); + if (is == null) { + throw new FileNotFoundException("File " + path + " was not found inside JAR."); + } + + // Open output stream and copy data between source file in JAR and the temporary file + OutputStream os = new FileOutputStream(temp); + try { + while ((readBytes = is.read(buffer)) != -1) { + os.write(buffer, 0, readBytes); + } + } finally { + // If read/write fails, close streams safely before throwing an exception + os.close(); + is.close(); + } + + // Finally, load the library + System.load(temp.getAbsolutePath()); + } + + /** + * load native library, this method will first try to load library from java.library.path, then + * try to load library in jar package. 
+ * + * @param libName library path + * @throws IOException exception + */ + private static void smartLoad(String libName) throws IOException { + addNativeDir(nativePath); + try { + System.loadLibrary(libName); + } catch (UnsatisfiedLinkError e) { + try { + String libraryFromJar = nativeResourcePath + System.mapLibraryName(libName); + loadLibraryFromJar(libraryFromJar); + } catch (IOException e1) { + throw e1; + } + } + } + + /** + * Add libPath to java.library.path, then native library in libPath would be load properly + * + * @param libPath library path + * @throws IOException exception + */ + private static void addNativeDir(String libPath) throws IOException { + try { + Field field = ClassLoader.class.getDeclaredField("usr_paths"); + field.setAccessible(true); + String[] paths = (String[]) field.get(null); + for (String path : paths) { + if (libPath.equals(path)) { + return; + } + } + String[] tmp = new String[paths.length + 1]; + System.arraycopy(paths, 0, tmp, 0, paths.length); + tmp[paths.length] = libPath; + field.set(null, tmp); + } catch (IllegalAccessException e) { + logger.error(e.getMessage()); + throw new IOException("Failed to get permissions to set library path"); + } catch (NoSuchFieldException e) { + logger.error(e.getMessage()); + throw new IOException("Failed to get field handle to set library path"); + } + } +} diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/XGBoost.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/XGBoost.java new file mode 100644 index 000000000..cea4ae5bf --- /dev/null +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/XGBoost.java @@ -0,0 +1,336 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ +package ml.dmlc.xgboost4j; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import java.util.*; + + +/** + * trainer for xgboost + * + * @author hzx + */ +public class XGBoost { + private static final Log logger = LogFactory.getLog(XGBoost.class); + + /** + * Train a booster with given parameters. + * + * @param params Booster params. + * @param dtrain Data to be trained. + * @param round Number of boosting iterations. + * @param watches a group of items to be evaluated during training, this allows user to watch + * performance on the validation set. 
+ * @param obj customized objective (set to null if not used) + * @param eval customized evaluation (set to null if not used) + * @return trained booster + * @throws XGBoostError native error + */ + public static Booster train(Map params, DMatrix dtrain, int round, + Map watches, IObjective obj, + IEvaluation eval) throws XGBoostError { + + //collect eval matrixs + String[] evalNames; + DMatrix[] evalMats; + List names = new ArrayList(); + List mats = new ArrayList(); + + for (Map.Entry evalEntry : watches.entrySet()) { + names.add(evalEntry.getKey()); + mats.add(evalEntry.getValue()); + } + + evalNames = names.toArray(new String[names.size()]); + evalMats = mats.toArray(new DMatrix[mats.size()]); + + //collect all data matrixs + DMatrix[] allMats; + if (evalMats != null && evalMats.length > 0) { + allMats = new DMatrix[evalMats.length + 1]; + allMats[0] = dtrain; + System.arraycopy(evalMats, 0, allMats, 1, evalMats.length); + } else { + allMats = new DMatrix[1]; + allMats[0] = dtrain; + } + + //initialize booster + Booster booster = new JavaBoosterImpl(params, allMats); + + //begin to train + for (int iter = 0; iter < round; iter++) { + if (obj != null) { + booster.update(dtrain, obj); + } else { + booster.update(dtrain, iter); + } + + //evaluation + if (evalMats != null && evalMats.length > 0) { + String evalInfo; + if (eval != null) { + evalInfo = booster.evalSet(evalMats, evalNames, eval); + } else { + evalInfo = booster.evalSet(evalMats, evalNames, iter); + } + logger.info(evalInfo); + } + } + return booster; + } + + /** + * init Booster from dMatrixs + * + * @param params parameters + * @param dMatrixs DMatrix array + * @throws XGBoostError native error + */ + public static Booster initBoostingModel( + Map params, + DMatrix[] dMatrixs) throws XGBoostError { + return new JavaBoosterImpl(params, dMatrixs); + } + + /** + * load model from modelPath + * + * @param params parameters + * @param modelPath booster modelPath (model generated by booster.saveModel) + * @throws XGBoostError native error + */ + public static Booster loadBoostModel(Map params, String modelPath) + throws XGBoostError { + return new JavaBoosterImpl(params, modelPath); + } + + /** + * Cross-validation with given paramaters. + * + * @param params Booster params. + * @param data Data to be trained. + * @param round Number of boosting iterations. + * @param nfold Number of folds in CV. + * @param metrics Evaluation metrics to be watched in CV. 
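A usage sketch of XGBoost.train with a watch list, matching the signature above; the file paths, parameter values, and class name are placeholders.

import java.util.HashMap;
import java.util.Map;
import ml.dmlc.xgboost4j.Booster;
import ml.dmlc.xgboost4j.DMatrix;
import ml.dmlc.xgboost4j.XGBoost;
import ml.dmlc.xgboost4j.XGBoostError;

public class TrainExample {
  public static void main(String[] args) throws XGBoostError {
    DMatrix dtrain = new DMatrix("train.svm.txt");   // placeholder paths
    DMatrix dtest = new DMatrix("test.svm.txt");

    Map<String, Object> params = new HashMap<String, Object>();
    params.put("eta", 0.1);
    params.put("max_depth", 3);
    params.put("objective", "binary:logistic");

    // every entry in watches is evaluated and logged after each boosting round
    Map<String, DMatrix> watches = new HashMap<String, DMatrix>();
    watches.put("train", dtrain);
    watches.put("test", dtest);

    // no custom objective or evaluation: pass null for both
    Booster booster = XGBoost.train(params, dtrain, 10, watches, null, null);
    float[][] predicts = booster.predict(dtest);
    System.out.println("rows predicted: " + predicts.length);

    booster.dispose();
    dtrain.dispose();
    dtest.dispose();
  }
}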
+ * @param obj customized objective (set to null if not used) + * @param eval customized evaluation (set to null if not used) + * @return evaluation history + * @throws XGBoostError native error + */ + public static String[] crossValiation( + Map params, + DMatrix data, + int round, + int nfold, + String[] metrics, + IObjective obj, + IEvaluation eval) throws XGBoostError { + CVPack[] cvPacks = makeNFold(data, nfold, params, metrics); + String[] evalHist = new String[round]; + String[] results = new String[cvPacks.length]; + for (int i = 0; i < round; i++) { + for (CVPack cvPack : cvPacks) { + if (obj != null) { + cvPack.update(obj); + } else { + cvPack.update(i); + } + } + + for (int j = 0; j < cvPacks.length; j++) { + if (eval != null) { + results[j] = cvPacks[j].eval(eval); + } else { + results[j] = cvPacks[j].eval(i); + } + } + + evalHist[i] = aggCVResults(results); + logger.info(evalHist[i]); + } + return evalHist; + } + + /** + * make an n-fold array of CVPack from random indices + * + * @param data original data + * @param nfold num of folds + * @param params booster parameters + * @param evalMetrics Evaluation metrics + * @return CV package array + * @throws XGBoostError native error + */ + private static CVPack[] makeNFold(DMatrix data, int nfold, Map params, + String[] evalMetrics) throws XGBoostError { + List samples = genRandPermutationNums(0, (int) data.rowNum()); + int step = samples.size() / nfold; + int[] testSlice = new int[step]; + int[] trainSlice = new int[samples.size() - step]; + int testid, trainid; + CVPack[] cvPacks = new CVPack[nfold]; + for (int i = 0; i < nfold; i++) { + testid = 0; + trainid = 0; + for (int j = 0; j < samples.size(); j++) { + if (j > (i * step) && j < (i * step + step) && testid < step) { + testSlice[testid] = samples.get(j); + testid++; + } else { + if (trainid < samples.size() - step) { + trainSlice[trainid] = samples.get(j); + trainid++; + } else { + testSlice[testid] = samples.get(j); + testid++; + } + } + } + + DMatrix dtrain = data.slice(trainSlice); + DMatrix dtest = data.slice(testSlice); + CVPack cvPack = new CVPack(dtrain, dtest, params); + //set eval types + if (evalMetrics != null) { + for (String type : evalMetrics) { + cvPack.booster.setParam("eval_metric", type); + } + } + cvPacks[i] = cvPack; + } + + return cvPacks; + } + + private static List genRandPermutationNums(int start, int end) { + List samples = new ArrayList(); + for (int i = start; i < end; i++) { + samples.add(i); + } + Collections.shuffle(samples); + return samples; + } + + /** + * Aggregate cross-validation results. 
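A sketch of the cross-validation entry point (note that the method is spelled crossValiation in this API); the path, parameter values, and class name are placeholders.

import java.util.HashMap;
import java.util.Map;
import ml.dmlc.xgboost4j.DMatrix;
import ml.dmlc.xgboost4j.XGBoost;
import ml.dmlc.xgboost4j.XGBoostError;

public class CrossValidationExample {
  public static void main(String[] args) throws XGBoostError {
    DMatrix data = new DMatrix("train.svm.txt");     // placeholder path

    Map<String, Object> params = new HashMap<String, Object>();
    params.put("max_depth", 3);
    params.put("objective", "binary:logistic");

    // 5-fold CV over 10 rounds, watching the "error" metric; no custom obj/eval
    String[] history = XGBoost.crossValiation(
        params, data, 10, 5, new String[]{"error"}, null, null);
    for (String round : history) {
      System.out.println(round);  // each entry aggregates the per-fold eval results for one round
    }

    data.dispose();
  }
}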
+ * + * @param results eval info from each data sample + * @return cross-validation eval info + */ + private static String aggCVResults(String[] results) { + Map> cvMap = new HashMap>(); + String aggResult = results[0].split("\t")[0]; + for (String result : results) { + String[] items = result.split("\t"); + for (int i = 1; i < items.length; i++) { + String[] tup = items[i].split(":"); + String key = tup[0]; + Float value = Float.valueOf(tup[1]); + if (!cvMap.containsKey(key)) { + cvMap.put(key, new ArrayList()); + } + cvMap.get(key).add(value); + } + } + + for (String key : cvMap.keySet()) { + float value = 0f; + for (Float tvalue : cvMap.get(key)) { + value += tvalue; + } + value /= cvMap.get(key).size(); + aggResult += String.format("\tcv-%s:%f", key, value); + } + + return aggResult; + } + + /** + * cross validation package for xgb + * + * @author hzx + */ + private static class CVPack { + DMatrix dtrain; + DMatrix dtest; + DMatrix[] dmats; + String[] names; + Booster booster; + + /** + * create an cross validation package + * + * @param dtrain train data + * @param dtest test data + * @param params parameters + * @throws XGBoostError native error + */ + public CVPack(DMatrix dtrain, DMatrix dtest, Map params) + throws XGBoostError { + dmats = new DMatrix[]{dtrain, dtest}; + booster = XGBoost.initBoostingModel(params, dmats); + names = new String[]{"train", "test"}; + this.dtrain = dtrain; + this.dtest = dtest; + } + + /** + * update one iteration + * + * @param iter iteration num + * @throws XGBoostError native error + */ + public void update(int iter) throws XGBoostError { + booster.update(dtrain, iter); + } + + /** + * update one iteration + * + * @param obj customized objective + * @throws XGBoostError native error + */ + public void update(IObjective obj) throws XGBoostError { + booster.update(dtrain, obj); + } + + /** + * evaluation + * + * @param iter iteration num + * @return evaluation + * @throws XGBoostError native error + */ + public String eval(int iter) throws XGBoostError { + return booster.evalSet(dmats, names, iter); + } + + /** + * evaluation + * + * @param eval customized eval + * @return evaluation + * @throws XGBoostError native error + */ + public String eval(IEvaluation eval) throws XGBoostError { + return booster.evalSet(dmats, names, eval); + } + } +} diff --git a/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/util/XGBoostError.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/XGBoostError.java similarity index 75% rename from java/xgboost4j/src/main/java/org/dmlc/xgboost4j/util/XGBoostError.java rename to jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/XGBoostError.java index dc7a9a0b2..1f62b22fc 100644 --- a/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/util/XGBoostError.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/XGBoostError.java @@ -1,10 +1,10 @@ /* - Copyright (c) 2014 by Contributors + Copyright (c) 2014 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - + http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software @@ -13,14 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. 
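A sketch of reloading a saved model and handling XGBoostError at the application boundary; the model path and class name are placeholders.

import java.util.HashMap;
import ml.dmlc.xgboost4j.Booster;
import ml.dmlc.xgboost4j.XGBoost;
import ml.dmlc.xgboost4j.XGBoostError;

public class LoadModelExample {
  public static void main(String[] args) {
    try {
      // "xgb.model" is a placeholder; any path written by booster.saveModel works
      Booster booster = XGBoost.loadBoostModel(new HashMap<String, Object>(), "xgb.model");
      System.out.println("features used by the model: " + booster.getFeatureScore().size());
      booster.dispose();
    } catch (XGBoostError e) {
      // the message is whatever XGBGetLastError() reported from the native layer
      System.err.println("XGBoost native call failed: " + e.getMessage());
    }
  }
}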
*/ -package org.dmlc.xgboost4j.util; +package ml.dmlc.xgboost4j; /** * custom error class for xgboost + * * @author hzx */ -public class XGBoostError extends Exception{ - public XGBoostError(String message) { - super(message); - } +public class XGBoostError extends Exception { + public XGBoostError(String message) { + super(message); + } } diff --git a/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/wrapper/XgboostJNI.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/XgboostJNI.java similarity index 77% rename from java/xgboost4j/src/main/java/org/dmlc/xgboost4j/wrapper/XgboostJNI.java rename to jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/XgboostJNI.java index 11cab988c..10ba1802b 100644 --- a/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/wrapper/XgboostJNI.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/XgboostJNI.java @@ -13,38 +13,71 @@ See the License for the specific language governing permissions and limitations under the License. */ -package org.dmlc.xgboost4j.wrapper; +package ml.dmlc.xgboost4j; /** * xgboost jni wrapper functions for xgboost_wrapper.h * change 2015-7-6: *use a long[] (length=1) as container of handle to get the output DMatrix or Booster + * * @author hzx */ -public class XgboostJNI { +class XgboostJNI { public final static native String XGBGetLastError(); + public final static native int XGDMatrixCreateFromFile(String fname, int silent, long[] out); - public final static native int XGDMatrixCreateFromCSR(long[] indptr, int[] indices, float[] data, long[] out); - public final static native int XGDMatrixCreateFromCSC(long[] colptr, int[] indices, float[] data, long[] out); - public final static native int XGDMatrixCreateFromMat(float[] data, int nrow, int ncol, float missing, long[] out); + + public final static native int XGDMatrixCreateFromCSR(long[] indptr, int[] indices, float[] data, + long[] out); + + public final static native int XGDMatrixCreateFromCSC(long[] colptr, int[] indices, float[] data, + long[] out); + + public final static native int XGDMatrixCreateFromMat(float[] data, int nrow, int ncol, + float missing, long[] out); + public final static native int XGDMatrixSliceDMatrix(long handle, int[] idxset, long[] out); + public final static native int XGDMatrixFree(long handle); + public final static native int XGDMatrixSaveBinary(long handle, String fname, int silent); + public final static native int XGDMatrixSetFloatInfo(long handle, String field, float[] array); + public final static native int XGDMatrixSetUIntInfo(long handle, String field, int[] array); + public final static native int XGDMatrixSetGroup(long handle, int[] group); + public final static native int XGDMatrixGetFloatInfo(long handle, String field, float[][] info); + public final static native int XGDMatrixGetUIntInfo(long handle, String filed, int[][] info); + public final static native int XGDMatrixNumRow(long handle, long[] row); + public final static native int XGBoosterCreate(long[] handles, long[] out); + public final static native int XGBoosterFree(long handle); + public final static native int XGBoosterSetParam(long handle, String name, String value); + public final static native int XGBoosterUpdateOneIter(long handle, int iter, long dtrain); - public final static native int XGBoosterBoostOneIter(long handle, long dtrain, float[] grad, float[] hess); - public final static native int XGBoosterEvalOneIter(long handle, int iter, long[] dmats, String[] evnames, String[] eval_info); - public final static native int XGBoosterPredict(long 
handle, long dmat, int option_mask, int ntree_limit, float[][] predicts); + + public final static native int XGBoosterBoostOneIter(long handle, long dtrain, float[] grad, + float[] hess); + + public final static native int XGBoosterEvalOneIter(long handle, int iter, long[] dmats, + String[] evnames, String[] eval_info); + + public final static native int XGBoosterPredict(long handle, long dmat, int option_mask, + int ntree_limit, float[][] predicts); + public final static native int XGBoosterLoadModel(long handle, String fname); + public final static native int XGBoosterSaveModel(long handle, String fname); + public final static native int XGBoosterLoadModelFromBuffer(long handle, long buf, long len); + public final static native int XGBoosterGetModelRaw(long handle, String[] out_string); - public final static native int XGBoosterDumpModel(long handle, String fmap, int with_stats, String[][] out_strings); + + public final static native int XGBoosterDumpModel(long handle, String fmap, int with_stats, + String[][] out_strings); } diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala new file mode 100644 index 000000000..5d5cd5619 --- /dev/null +++ b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/Booster.scala @@ -0,0 +1,189 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + +package ml.dmlc.xgboost4j.scala + +import java.io.IOException + +import scala.collection.mutable + +import ml.dmlc.xgboost4j.XGBoostError + + +trait Booster { + + + /** + * set parameter + * + * @param key param name + * @param value param value + */ + @throws(classOf[XGBoostError]) + def setParam(key: String, value: String) + + /** + * set parameters + * + * @param params parameters key-value map + */ + @throws(classOf[XGBoostError]) + def setParams(params: Map[String, AnyRef]) + + /** + * Update (one iteration) + * + * @param dtrain training data + * @param iter current iteration number + */ + @throws(classOf[XGBoostError]) + def update(dtrain: DMatrix, iter: Int) + + /** + * update with customize obj func + * + * @param dtrain training data + * @param obj customized objective class + */ + @throws(classOf[XGBoostError]) + def update(dtrain: DMatrix, obj: ObjectiveTrait) + + /** + * update with give grad and hess + * + * @param dtrain training data + * @param grad first order of gradient + * @param hess seconde order of gradient + */ + @throws(classOf[XGBoostError]) + def boost(dtrain: DMatrix, grad: Array[Float], hess: Array[Float]) + + /** + * evaluate with given dmatrixs. 
+ * + * @param evalMatrixs dmatrixs for evaluation + * @param evalNames name for eval dmatrixs, used for check results + * @param iter current eval iteration + * @return eval information + */ + @throws(classOf[XGBoostError]) + def evalSet(evalMatrixs: Array[DMatrix], evalNames: Array[String], iter: Int): String + + /** + * evaluate with given customized Evaluation class + * + * @param evalMatrixs evaluation matrix + * @param evalNames evaluation names + * @param eval custom evaluator + * @return eval information + */ + @throws(classOf[XGBoostError]) + def evalSet(evalMatrixs: Array[DMatrix], evalNames: Array[String], eval: EvalTrait): String + + /** + * Predict with data + * + * @param data dmatrix storing the input + * @return predict result + */ + @throws(classOf[XGBoostError]) + def predict(data: DMatrix): Array[Array[Float]] + + /** + * Predict with data + * + * @param data dmatrix storing the input + * @param outPutMargin Whether to output the raw untransformed margin value. + * @return predict result + */ + @throws(classOf[XGBoostError]) + def predict(data: DMatrix, outPutMargin: Boolean): Array[Array[Float]] + + /** + * Predict with data + * + * @param data dmatrix storing the input + * @param outPutMargin Whether to output the raw untransformed margin value. + * @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees). + * @return predict result + */ + @throws(classOf[XGBoostError]) + def predict(data: DMatrix, outPutMargin: Boolean, treeLimit: Int): Array[Array[Float]] + + /** + * Predict with data + * + * @param data dmatrix storing the input + * @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees). + * @param predLeaf When this option is on, the output will be a matrix of (nsample, ntrees), + * nsample = data.numRow with each record indicating the predicted leaf index of + * each sample in each tree. Note that the leaf index of a tree is unique per + * tree, so you may find leaf 1 in both tree 1 and tree 0. + * @return predict result + * @throws XGBoostError native error + */ + @throws(classOf[XGBoostError]) + def predict(data: DMatrix, treeLimit: Int, predLeaf: Boolean): Array[Array[Float]] + + /** + * save model to modelPath + * + * @param modelPath model path + */ + @throws(classOf[XGBoostError]) + def saveModel(modelPath: String) + + /** + * Dump model into a text file. + * + * @param modelPath file to save dumped model info + * @param withStats bool Controls whether the split statistics are output. + */ + @throws(classOf[IOException]) + @throws(classOf[XGBoostError]) + def dumpModel(modelPath: String, withStats: Boolean) + + /** + * Dump model into a text file. + * + * @param modelPath file to save dumped model info + * @param featureMap featureMap file + * @param withStats bool + * Controls whether the split statistics are output. 
+ */ + @throws(classOf[IOException]) + @throws(classOf[XGBoostError]) + def dumpModel(modelPath: String, featureMap: String, withStats: Boolean) + + /** + * get importance of each feature + * + * @return featureMap key: feature index, value: feature importance score + */ + @throws(classOf[XGBoostError]) + def getFeatureScore: mutable.Map[String, Integer] + + /** + * get importance of each feature + * + * @param featureMap file to save dumped model info + * @return featureMap key: feature index, value: feature importance score + */ + @throws(classOf[XGBoostError]) + def getFeatureScore(featureMap: String): mutable.Map[String, Integer] + + def dispose +} diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/DMatrix.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/DMatrix.scala new file mode 100644 index 000000000..73fafc7f0 --- /dev/null +++ b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/DMatrix.scala @@ -0,0 +1,177 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + +package ml.dmlc.xgboost4j.scala + +import ml.dmlc.xgboost4j.{DMatrix => JDMatrix, XGBoostError} + +class DMatrix private(private[scala] val jDMatrix: JDMatrix) { + + /** + * init DMatrix from file (svmlight format) + * + * @param dataPath path of data file + * @throws XGBoostError native error + */ + def this(dataPath: String) { + this(new JDMatrix(dataPath)) + } + + /** + * create DMatrix from sparse matrix + * + * @param headers index to headers (rowHeaders for CSR or colHeaders for CSC) + * @param indices Indices (colIndexs for CSR or rowIndexs for CSC) + * @param data non zero values (sequence by row for CSR or by col for CSC) + * @param st sparse matrix type (CSR or CSC) + */ + @throws(classOf[XGBoostError]) + def this(headers: Array[Long], indices: Array[Int], data: Array[Float], st: JDMatrix.SparseType) { + this(new JDMatrix(headers, indices, data, st)) + } + + /** + * create DMatrix from dense matrix + * + * @param data data values + * @param nrow number of rows + * @param ncol number of columns + */ + @throws(classOf[XGBoostError]) + def this(data: Array[Float], nrow: Int, ncol: Int) { + this(new JDMatrix(data, nrow, ncol)) + } + + /** + * set label of dmatrix + * + * @param labels labels + */ + @throws(classOf[XGBoostError]) + def setLabel(labels: Array[Float]): Unit = { + jDMatrix.setLabel(labels) + } + + /** + * set weight of each instance + * + * @param weights weights + */ + @throws(classOf[XGBoostError]) + def setWeight(weights: Array[Float]): Unit = { + jDMatrix.setWeight(weights) + } + + /** + * if specified, xgboost will start from this init margin + * can be used to specify initial prediction to boost from + * + * @param baseMargin base margin + */ + @throws(classOf[XGBoostError]) + def setBaseMargin(baseMargin: Array[Float]): Unit = { + jDMatrix.setBaseMargin(baseMargin) + } + + /** + * if specified, xgboost will start from this init margin + * can be used to specify initial prediction to boost from + * + * @param 
baseMargin base margin + */ + @throws(classOf[XGBoostError]) + def setBaseMargin(baseMargin: Array[Array[Float]]): Unit = { + jDMatrix.setBaseMargin(baseMargin) + } + + /** + * Set group sizes of DMatrix (used for ranking) + * + * @param group group size as array + */ + @throws(classOf[XGBoostError]) + def setGroup(group: Array[Int]): Unit = { + jDMatrix.setGroup(group) + } + + /** + * get label values + * + * @return label + */ + @throws(classOf[XGBoostError]) + def getLabel: Array[Float] = { + jDMatrix.getLabel + } + + /** + * get weight of the DMatrix + * + * @return weights + */ + @throws(classOf[XGBoostError]) + def getWeight: Array[Float] = { + jDMatrix.getWeight + } + + /** + * get base margin of the DMatrix + * + * @return base margin + */ + @throws(classOf[XGBoostError]) + def getBaseMargin: Array[Float] = { + jDMatrix.getBaseMargin + } + + /** + * Slice the DMatrix and return a new DMatrix that only contains `rowIndex`. + * + * @param rowIndex row index + * @return sliced new DMatrix + */ + @throws(classOf[XGBoostError]) + def slice(rowIndex: Array[Int]): DMatrix = { + new DMatrix(jDMatrix.slice(rowIndex)) + } + + /** + * get the row number of DMatrix + * + * @return number of rows + */ + @throws(classOf[XGBoostError]) + def rowNum: Long = { + jDMatrix.rowNum + } + + /** + * save DMatrix to filePath + * + * @param filePath file path + */ + def saveBinary(filePath: String): Unit = { + jDMatrix.saveBinary(filePath) + } + + def getHandle: Long = { + jDMatrix.getHandle + } + + def delete(): Unit = { + jDMatrix.dispose() + } +} diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/EvalTrait.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/EvalTrait.scala new file mode 100644 index 000000000..461f515a1 --- /dev/null +++ b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/EvalTrait.scala @@ -0,0 +1,38 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + +package ml.dmlc.xgboost4j.scala + +import ml.dmlc.xgboost4j.IEvaluation + +trait EvalTrait extends IEvaluation { + + /** + * get evaluate metric + * + * @return evalMetric + */ + def getMetric: String + + /** + * evaluate with predicts and data + * + * @param predicts predictions as array + * @param dmat data matrix to evaluate + * @return result of the metric + */ + def eval(predicts: Array[Array[Float]], dmat: DMatrix): Float +} diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/ObjectiveTrait.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/ObjectiveTrait.scala new file mode 100644 index 000000000..c5df8aead --- /dev/null +++ b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/ObjectiveTrait.scala @@ -0,0 +1,30 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + +package ml.dmlc.xgboost4j.scala + +import ml.dmlc.xgboost4j.IObjective + +trait ObjectiveTrait extends IObjective { + /** + * user define objective function, return gradient and second order gradient + * + * @param predicts untransformed margin predicts + * @param dtrain training data + * @return List with two float array, correspond to first order grad and second order grad + */ + def getGradient(predicts: Array[Array[Float]], dtrain: DMatrix): java.util.List[Array[Float]] +} diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/ScalaBoosterImpl.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/ScalaBoosterImpl.scala new file mode 100644 index 000000000..06af4541b --- /dev/null +++ b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/ScalaBoosterImpl.scala @@ -0,0 +1,100 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ */ + +package ml.dmlc.xgboost4j.scala + +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import ml.dmlc.xgboost4j.{Booster => JBooster, IEvaluation, IObjective} + +private[scala] class ScalaBoosterImpl private[xgboost4j](booster: JBooster) extends Booster { + + override def setParam(key: String, value: String): Unit = { + booster.setParam(key, value) + } + + override def update(dtrain: DMatrix, iter: Int): Unit = { + booster.update(dtrain.jDMatrix, iter) + } + + override def update(dtrain: DMatrix, obj: ObjectiveTrait): Unit = { + booster.update(dtrain.jDMatrix, obj) + } + + override def dumpModel(modelPath: String, withStats: Boolean): Unit = { + booster.dumpModel(modelPath, withStats) + } + + override def dumpModel(modelPath: String, featureMap: String, withStats: Boolean): Unit = { + booster.dumpModel(modelPath, featureMap, withStats) + } + + override def setParams(params: Map[String, AnyRef]): Unit = { + booster.setParams(params.asJava) + } + + override def evalSet(evalMatrixs: Array[DMatrix], evalNames: Array[String], iter: Int): String = { + booster.evalSet(evalMatrixs.map(_.jDMatrix), evalNames, iter) + } + + override def evalSet(evalMatrixs: Array[DMatrix], evalNames: Array[String], eval: EvalTrait): + String = { + booster.evalSet(evalMatrixs.map(_.jDMatrix), evalNames, eval) + } + + override def dispose: Unit = { + booster.dispose() + } + + override def predict(data: DMatrix): Array[Array[Float]] = { + booster.predict(data.jDMatrix) + } + + override def predict(data: DMatrix, outPutMargin: Boolean): Array[Array[Float]] = { + booster.predict(data.jDMatrix, outPutMargin) + } + + override def predict(data: DMatrix, outPutMargin: Boolean, treeLimit: Int): + Array[Array[Float]] = { + booster.predict(data.jDMatrix, outPutMargin, treeLimit) + } + + override def predict(data: DMatrix, treeLimit: Int, predLeaf: Boolean): Array[Array[Float]] = { + booster.predict(data.jDMatrix, treeLimit, predLeaf) + } + + override def boost(dtrain: DMatrix, grad: Array[Float], hess: Array[Float]): Unit = { + booster.boost(dtrain.jDMatrix, grad, hess) + } + + override def getFeatureScore: mutable.Map[String, Integer] = { + booster.getFeatureScore.asScala + } + + override def getFeatureScore(featureMap: String): mutable.Map[String, Integer] = { + booster.getFeatureScore(featureMap).asScala + } + + override def saveModel(modelPath: String): Unit = { + booster.saveModel(modelPath) + } + + override def finalize(): Unit = { + super.finalize() + dispose + } +} diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/XGBoost.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/XGBoost.scala new file mode 100644 index 000000000..737e4765d --- /dev/null +++ b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/scala/XGBoost.scala @@ -0,0 +1,52 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ */ + +package ml.dmlc.xgboost4j.scala + +import _root_.scala.collection.JavaConverters._ +import ml.dmlc.xgboost4j.{XGBoost => JXGBoost} + +object XGBoost { + + def train(params: Map[String, AnyRef], dtrain: DMatrix, round: Int, + watches: Map[String, DMatrix], obj: ObjectiveTrait, eval: EvalTrait): Booster = { + val jWatches = watches.map{case (name, matrix) => (name, matrix.jDMatrix)} + val xgboostInJava = JXGBoost.train(params.asJava, dtrain.jDMatrix, round, jWatches.asJava, + obj, eval) + new ScalaBoosterImpl(xgboostInJava) + } + + def crossValiation( + params: Map[String, AnyRef], + data: DMatrix, + round: Int, + nfold: Int, + metrics: Array[String], + obj: ObjectiveTrait, + eval: EvalTrait): Array[String] = { + JXGBoost.crossValiation(params.asJava, data.jDMatrix, round, nfold, metrics, obj, eval) + } + + def initBoostModel(params: Map[String, AnyRef], dMatrixs: Array[DMatrix]): Booster = { + val xgboostInJava = JXGBoost.initBoostingModel(params.asJava, dMatrixs.map(_.jDMatrix)) + new ScalaBoosterImpl(xgboostInJava) + } + + def loadBoostModel(params: Map[String, AnyRef], modelPath: String): Booster = { + val xgboostInJava = JXGBoost.loadBoostModel(params.asJava, modelPath) + new ScalaBoosterImpl(xgboostInJava) + } +} diff --git a/java/xgboost4j_wrapper.cpp b/jvm-packages/xgboost4j/src/native/xgboost4j.cpp similarity index 81% rename from java/xgboost4j_wrapper.cpp rename to jvm-packages/xgboost4j/src/native/xgboost4j.cpp index 865426752..0d976a33f 100644 --- a/java/xgboost4j_wrapper.cpp +++ b/jvm-packages/xgboost4j/src/native/xgboost4j.cpp @@ -13,7 +13,7 @@ */ #include "xgboost/c_api.h" -#include "xgboost4j_wrapper.h" +#include "xgboost4j.h" #include //helper functions @@ -24,7 +24,7 @@ void setHandle(JNIEnv *jenv, jlongArray jhandle, void* handle) { jenv->SetLongArrayRegion(jhandle, 0, 1, (const jlong*) out); } -JNIEXPORT jstring JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBGetLastError +JNIEXPORT jstring JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBGetLastError (JNIEnv *jenv, jclass jcls) { jstring jresult = 0 ; const char* result = XGBGetLastError(); @@ -32,7 +32,7 @@ JNIEXPORT jstring JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBGetLastE return jresult; } -JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromFile +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromFile (JNIEnv *jenv, jclass jcls, jstring jfname, jint jsilent, jlongArray jout) { DMatrixHandle result; const char* fname = jenv->GetStringUTFChars(jfname, 0); @@ -43,11 +43,11 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreat } /* - * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI + * Class: ml_dmlc_xgboost4j_XgboostJNI * Method: XGDMatrixCreateFromCSR * Signature: ([J[J[F)J */ -JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromCSR +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromCSR (JNIEnv *jenv, jclass jcls, jlongArray jindptr, jintArray jindices, jfloatArray jdata, jlongArray jout) { DMatrixHandle result; jlong* indptr = jenv->GetLongArrayElements(jindptr, 0); @@ -65,11 +65,11 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreat } /* - * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI + * Class: ml_dmlc_xgboost4j_XgboostJNI * Method: XGDMatrixCreateFromCSC * Signature: ([J[J[F)J */ -JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromCSC +JNIEXPORT jint JNICALL 
Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromCSC (JNIEnv *jenv, jclass jcls, jlongArray jindptr, jintArray jindices, jfloatArray jdata, jlongArray jout) { DMatrixHandle result; jlong* indptr = jenv->GetLongArrayElements(jindptr, NULL); @@ -89,11 +89,11 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreat } /* - * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI + * Class: ml_dmlc_xgboost4j_XgboostJNI * Method: XGDMatrixCreateFromMat * Signature: ([FIIF)J */ -JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromMat +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromMat (JNIEnv *jenv, jclass jcls, jfloatArray jdata, jint jnrow, jint jncol, jfloat jmiss, jlongArray jout) { DMatrixHandle result; jfloat* data = jenv->GetFloatArrayElements(jdata, 0); @@ -107,11 +107,11 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreat } /* - * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI + * Class: ml_dmlc_xgboost4j_XgboostJNI * Method: XGDMatrixSliceDMatrix * Signature: (J[I)J */ -JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSliceDMatrix +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSliceDMatrix (JNIEnv *jenv, jclass jcls, jlong jhandle, jintArray jindexset, jlongArray jout) { DMatrixHandle result; DMatrixHandle handle = (DMatrixHandle) jhandle; @@ -128,11 +128,11 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSlice } /* - * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI + * Class: ml_dmlc_xgboost4j_XgboostJNI * Method: XGDMatrixFree * Signature: (J)V */ -JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixFree +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixFree (JNIEnv *jenv, jclass jcls, jlong jhandle) { DMatrixHandle handle = (DMatrixHandle) jhandle; int ret = XGDMatrixFree(handle); @@ -140,11 +140,11 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixFree } /* - * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI + * Class: ml_dmlc_xgboost4j_XgboostJNI * Method: XGDMatrixSaveBinary * Signature: (JLjava/lang/String;I)V */ -JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSaveBinary +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSaveBinary (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfname, jint jsilent) { DMatrixHandle handle = (DMatrixHandle) jhandle; const char* fname = jenv->GetStringUTFChars(jfname, 0); @@ -154,11 +154,11 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSaveB } /* - * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI + * Class: ml_dmlc_xgboost4j_XgboostJNI * Method: XGDMatrixSetFloatInfo * Signature: (JLjava/lang/String;[F)V */ -JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetFloatInfo +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSetFloatInfo (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jfloatArray jarray) { DMatrixHandle handle = (DMatrixHandle) jhandle; const char* field = jenv->GetStringUTFChars(jfield, 0); @@ -173,11 +173,11 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetFl } /* - * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI + * Class: ml_dmlc_xgboost4j_XgboostJNI * Method: XGDMatrixSetUIntInfo * Signature: (JLjava/lang/String;[I)V */ -JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetUIntInfo +JNIEXPORT jint 
JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSetUIntInfo (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jintArray jarray) { DMatrixHandle handle = (DMatrixHandle) jhandle; const char* field = jenv->GetStringUTFChars(jfield, 0); @@ -192,11 +192,11 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetUI } /* - * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI + * Class: ml_dmlc_xgboost4j_XgboostJNI * Method: XGDMatrixSetGroup * Signature: (J[I)V */ -JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetGroup +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSetGroup (JNIEnv * jenv, jclass jcls, jlong jhandle, jintArray jarray) { DMatrixHandle handle = (DMatrixHandle) jhandle; jint* array = jenv->GetIntArrayElements(jarray, NULL); @@ -208,11 +208,11 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetGr } /* - * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI + * Class: ml_dmlc_xgboost4j_XgboostJNI * Method: XGDMatrixGetFloatInfo * Signature: (JLjava/lang/String;)[F */ -JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixGetFloatInfo +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixGetFloatInfo (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jobjectArray jout) { DMatrixHandle handle = (DMatrixHandle) jhandle; const char* field = jenv->GetStringUTFChars(jfield, 0); @@ -230,11 +230,11 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixGetFl } /* - * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI + * Class: ml_dmlc_xgboost4j_XgboostJNI * Method: XGDMatrixGetUIntInfo * Signature: (JLjava/lang/String;)[I */ -JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixGetUIntInfo +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixGetUIntInfo (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jobjectArray jout) { DMatrixHandle handle = (DMatrixHandle) jhandle; const char* field = jenv->GetStringUTFChars(jfield, 0); @@ -251,11 +251,11 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixGetUI } /* - * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI + * Class: ml_dmlc_xgboost4j_XgboostJNI * Method: XGDMatrixNumRow * Signature: (J)J */ -JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixNumRow +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixNumRow (JNIEnv *jenv, jclass jcls, jlong jhandle, jlongArray jout) { DMatrixHandle handle = (DMatrixHandle) jhandle; bst_ulong result[1]; @@ -265,11 +265,11 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixNumRo } /* - * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI + * Class: ml_dmlc_xgboost4j_XgboostJNI * Method: XGBoosterCreate * Signature: ([J)J */ -JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterCreate +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterCreate (JNIEnv *jenv, jclass jcls, jlongArray jhandles, jlongArray jout) { DMatrixHandle* handles; bst_ulong len = 0; @@ -298,11 +298,11 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterCreat } /* - * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI + * Class: ml_dmlc_xgboost4j_XgboostJNI * Method: XGBoosterFree * Signature: (J)V */ -JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterFree +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterFree (JNIEnv *jenv, jclass jcls, jlong jhandle) { 
BoosterHandle handle = (BoosterHandle) jhandle; return XGBoosterFree(handle); @@ -310,11 +310,11 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterFree /* - * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI + * Class: ml_dmlc_xgboost4j_XgboostJNI * Method: XGBoosterSetParam * Signature: (JLjava/lang/String;Ljava/lang/String;)V */ -JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterSetParam +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterSetParam (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jname, jstring jvalue) { BoosterHandle handle = (BoosterHandle) jhandle; const char* name = jenv->GetStringUTFChars(jname, 0); @@ -327,11 +327,11 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterSetPa } /* - * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI + * Class: ml_dmlc_xgboost4j_XgboostJNI * Method: XGBoosterUpdateOneIter * Signature: (JIJ)V */ -JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterUpdateOneIter +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterUpdateOneIter (JNIEnv *jenv, jclass jcls, jlong jhandle, jint jiter, jlong jdtrain) { BoosterHandle handle = (BoosterHandle) jhandle; DMatrixHandle dtrain = (DMatrixHandle) jdtrain; @@ -339,11 +339,11 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterUpdat } /* - * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI + * Class: ml_dmlc_xgboost4j_XgboostJNI * Method: XGBoosterBoostOneIter * Signature: (JJ[F[F)V */ -JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterBoostOneIter +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterBoostOneIter (JNIEnv *jenv, jclass jcls, jlong jhandle, jlong jdtrain, jfloatArray jgrad, jfloatArray jhess) { BoosterHandle handle = (BoosterHandle) jhandle; DMatrixHandle dtrain = (DMatrixHandle) jdtrain; @@ -358,11 +358,11 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterBoost } /* - * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI + * Class: ml_dmlc_xgboost4j_XgboostJNI * Method: XGBoosterEvalOneIter * Signature: (JI[J[Ljava/lang/String;)Ljava/lang/String; */ -JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterEvalOneIter +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterEvalOneIter (JNIEnv *jenv, jclass jcls, jlong jhandle, jint jiter, jlongArray jdmats, jobjectArray jevnames, jobjectArray jout) { BoosterHandle handle = (BoosterHandle) jhandle; DMatrixHandle* dmats = 0; @@ -406,11 +406,11 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterEvalO } /* - * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI + * Class: ml_dmlc_xgboost4j_XgboostJNI * Method: XGBoosterPredict * Signature: (JJIJ)[F */ -JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterPredict +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterPredict (JNIEnv *jenv, jclass jcls, jlong jhandle, jlong jdmat, jint joption_mask, jint jntree_limit, jobjectArray jout) { BoosterHandle handle = (BoosterHandle) jhandle; DMatrixHandle dmat = (DMatrixHandle) jdmat; @@ -426,11 +426,11 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterPredi } /* - * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI + * Class: ml_dmlc_xgboost4j_XgboostJNI * Method: XGBoosterLoadModel * Signature: (JLjava/lang/String;)V */ -JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterLoadModel +JNIEXPORT jint JNICALL 
Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterLoadModel (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfname) { BoosterHandle handle = (BoosterHandle) jhandle; const char* fname = jenv->GetStringUTFChars(jfname, 0); @@ -441,11 +441,11 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterLoadM } /* - * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI + * Class: ml_dmlc_xgboost4j_XgboostJNI * Method: XGBoosterSaveModel * Signature: (JLjava/lang/String;)V */ -JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterSaveModel +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterSaveModel (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfname) { BoosterHandle handle = (BoosterHandle) jhandle; const char* fname = jenv->GetStringUTFChars(jfname, 0); @@ -457,11 +457,11 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterSaveM } /* - * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI + * Class: ml_dmlc_xgboost4j_XgboostJNI * Method: XGBoosterLoadModelFromBuffer * Signature: (JJJ)V */ -JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterLoadModelFromBuffer +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterLoadModelFromBuffer (JNIEnv *jenv, jclass jcls, jlong jhandle, jlong jbuf, jlong jlen) { BoosterHandle handle = (BoosterHandle) jhandle; void *buf = (void*) jbuf; @@ -469,11 +469,11 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterLoadM } /* - * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI + * Class: ml_dmlc_xgboost4j_XgboostJNI * Method: XGBoosterGetModelRaw * Signature: (J)Ljava/lang/String; */ -JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterGetModelRaw +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterGetModelRaw (JNIEnv * jenv, jclass jcls, jlong jhandle, jobjectArray jout) { BoosterHandle handle = (BoosterHandle) jhandle; bst_ulong len = 0; @@ -488,11 +488,11 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterGetMo } /* - * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI + * Class: ml_dmlc_xgboost4j_XgboostJNI * Method: XGBoosterDumpModel * Signature: (JLjava/lang/String;I)[Ljava/lang/String; */ -JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterDumpModel +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterDumpModel (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfmap, jint jwith_stats, jobjectArray jout) { BoosterHandle handle = (BoosterHandle) jhandle; const char *fmap = jenv->GetStringUTFChars(jfmap, 0); @@ -510,4 +510,4 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterDumpM if (fmap) jenv->ReleaseStringUTFChars(jfmap, (const char *)fmap); return ret; -} \ No newline at end of file +} diff --git a/jvm-packages/xgboost4j/src/native/xgboost4j.h b/jvm-packages/xgboost4j/src/native/xgboost4j.h new file mode 100644 index 000000000..d93da0ee6 --- /dev/null +++ b/jvm-packages/xgboost4j/src/native/xgboost4j.h @@ -0,0 +1,221 @@ +/* DO NOT EDIT THIS FILE - it is machine generated */ +#include <jni.h> +/* Header for class ml_dmlc_xgboost4j_XgboostJNI */ + +#ifndef _Included_ml_dmlc_xgboost4j_XgboostJNI +#define _Included_ml_dmlc_xgboost4j_XgboostJNI +#ifdef __cplusplus +extern "C" { +#endif +/* + * Class: ml_dmlc_xgboost4j_XgboostJNI + * Method: XGBGetLastError + * Signature: ()Ljava/lang/String; + */ +JNIEXPORT jstring JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBGetLastError + (JNIEnv *, jclass); + +/* + * Class:
ml_dmlc_xgboost4j_XgboostJNI + * Method: XGDMatrixCreateFromFile + * Signature: (Ljava/lang/String;I[J)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromFile + (JNIEnv *, jclass, jstring, jint, jlongArray); + +/* + * Class: ml_dmlc_xgboost4j_XgboostJNI + * Method: XGDMatrixCreateFromCSR + * Signature: ([J[I[F[J)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromCSR + (JNIEnv *, jclass, jlongArray, jintArray, jfloatArray, jlongArray); + +/* + * Class: ml_dmlc_xgboost4j_XgboostJNI + * Method: XGDMatrixCreateFromCSC + * Signature: ([J[I[F[J)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromCSC + (JNIEnv *, jclass, jlongArray, jintArray, jfloatArray, jlongArray); + +/* + * Class: ml_dmlc_xgboost4j_XgboostJNI + * Method: XGDMatrixCreateFromMat + * Signature: ([FIIF[J)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromMat + (JNIEnv *, jclass, jfloatArray, jint, jint, jfloat, jlongArray); + +/* + * Class: ml_dmlc_xgboost4j_XgboostJNI + * Method: XGDMatrixSliceDMatrix + * Signature: (J[I[J)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSliceDMatrix + (JNIEnv *, jclass, jlong, jintArray, jlongArray); + +/* + * Class: ml_dmlc_xgboost4j_XgboostJNI + * Method: XGDMatrixFree + * Signature: (J)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixFree + (JNIEnv *, jclass, jlong); + +/* + * Class: ml_dmlc_xgboost4j_XgboostJNI + * Method: XGDMatrixSaveBinary + * Signature: (JLjava/lang/String;I)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSaveBinary + (JNIEnv *, jclass, jlong, jstring, jint); + +/* + * Class: ml_dmlc_xgboost4j_XgboostJNI + * Method: XGDMatrixSetFloatInfo + * Signature: (JLjava/lang/String;[F)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSetFloatInfo + (JNIEnv *, jclass, jlong, jstring, jfloatArray); + +/* + * Class: ml_dmlc_xgboost4j_XgboostJNI + * Method: XGDMatrixSetUIntInfo + * Signature: (JLjava/lang/String;[I)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSetUIntInfo + (JNIEnv *, jclass, jlong, jstring, jintArray); + +/* + * Class: ml_dmlc_xgboost4j_XgboostJNI + * Method: XGDMatrixSetGroup + * Signature: (J[I)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSetGroup + (JNIEnv *, jclass, jlong, jintArray); + +/* + * Class: ml_dmlc_xgboost4j_XgboostJNI + * Method: XGDMatrixGetFloatInfo + * Signature: (JLjava/lang/String;[[F)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixGetFloatInfo + (JNIEnv *, jclass, jlong, jstring, jobjectArray); + +/* + * Class: ml_dmlc_xgboost4j_XgboostJNI + * Method: XGDMatrixGetUIntInfo + * Signature: (JLjava/lang/String;[[I)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixGetUIntInfo + (JNIEnv *, jclass, jlong, jstring, jobjectArray); + +/* + * Class: ml_dmlc_xgboost4j_XgboostJNI + * Method: XGDMatrixNumRow + * Signature: (J[J)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixNumRow + (JNIEnv *, jclass, jlong, jlongArray); + +/* + * Class: ml_dmlc_xgboost4j_XgboostJNI + * Method: XGBoosterCreate + * Signature: ([J[J)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterCreate + (JNIEnv *, jclass, jlongArray, jlongArray); + +/* + * Class: ml_dmlc_xgboost4j_XgboostJNI + * Method: XGBoosterFree + * Signature: (J)I + */ +JNIEXPORT jint JNICALL 
Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterFree + (JNIEnv *, jclass, jlong); + +/* + * Class: ml_dmlc_xgboost4j_XgboostJNI + * Method: XGBoosterSetParam + * Signature: (JLjava/lang/String;Ljava/lang/String;)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterSetParam + (JNIEnv *, jclass, jlong, jstring, jstring); + +/* + * Class: ml_dmlc_xgboost4j_XgboostJNI + * Method: XGBoosterUpdateOneIter + * Signature: (JIJ)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterUpdateOneIter + (JNIEnv *, jclass, jlong, jint, jlong); + +/* + * Class: ml_dmlc_xgboost4j_XgboostJNI + * Method: XGBoosterBoostOneIter + * Signature: (JJ[F[F)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterBoostOneIter + (JNIEnv *, jclass, jlong, jlong, jfloatArray, jfloatArray); + +/* + * Class: ml_dmlc_xgboost4j_XgboostJNI + * Method: XGBoosterEvalOneIter + * Signature: (JI[J[Ljava/lang/String;[Ljava/lang/String;)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterEvalOneIter + (JNIEnv *, jclass, jlong, jint, jlongArray, jobjectArray, jobjectArray); + +/* + * Class: ml_dmlc_xgboost4j_XgboostJNI + * Method: XGBoosterPredict + * Signature: (JJII[[F)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterPredict + (JNIEnv *, jclass, jlong, jlong, jint, jint, jobjectArray); + +/* + * Class: ml_dmlc_xgboost4j_XgboostJNI + * Method: XGBoosterLoadModel + * Signature: (JLjava/lang/String;)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterLoadModel + (JNIEnv *, jclass, jlong, jstring); + +/* + * Class: ml_dmlc_xgboost4j_XgboostJNI + * Method: XGBoosterSaveModel + * Signature: (JLjava/lang/String;)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterSaveModel + (JNIEnv *, jclass, jlong, jstring); + +/* + * Class: ml_dmlc_xgboost4j_XgboostJNI + * Method: XGBoosterLoadModelFromBuffer + * Signature: (JJJ)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterLoadModelFromBuffer + (JNIEnv *, jclass, jlong, jlong, jlong); + +/* + * Class: ml_dmlc_xgboost4j_XgboostJNI + * Method: XGBoosterGetModelRaw + * Signature: (J[Ljava/lang/String;)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterGetModelRaw + (JNIEnv *, jclass, jlong, jobjectArray); + +/* + * Class: ml_dmlc_xgboost4j_XgboostJNI + * Method: XGBoosterDumpModel + * Signature: (JLjava/lang/String;I[[Ljava/lang/String;)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterDumpModel + (JNIEnv *, jclass, jlong, jstring, jint, jobjectArray); + +#ifdef __cplusplus +} +#endif +#endif diff --git a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/BoosterImplTest.java b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/BoosterImplTest.java new file mode 100644 index 000000000..e44bc95bc --- /dev/null +++ b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/BoosterImplTest.java @@ -0,0 +1,138 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ */ +package ml.dmlc.xgboost4j; + +import junit.framework.TestCase; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.junit.Test; + +import java.util.*; + +/** + * test cases for Booster + * + * @author hzx + */ +public class BoosterImplTest { + public static class EvalError implements IEvaluation { + private static final Log logger = LogFactory.getLog(EvalError.class); + + String evalMetric = "custom_error"; + + public EvalError() { + } + + @Override + public String getMetric() { + return evalMetric; + } + + @Override + public float eval(float[][] predicts, DMatrix dmat) { + float error = 0f; + float[] labels; + try { + labels = dmat.getLabel(); + } catch (XGBoostError ex) { + logger.error(ex); + return -1f; + } + int nrow = predicts.length; + for (int i = 0; i < nrow; i++) { + if (labels[i] == 0f && predicts[i][0] > 0) { + error++; + } else if (labels[i] == 1f && predicts[i][0] <= 0) { + error++; + } + } + + return error / labels.length; + } + } + + @Test + public void testBoosterBasic() throws XGBoostError { + DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); + DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test"); + + //set params + Map<String, Object> paramMap = new HashMap<String, Object>() { + { + put("eta", 1.0); + put("max_depth", 2); + put("silent", 1); + put("objective", "binary:logistic"); + } + }; + + //set watchList + HashMap<String, DMatrix> watches = new HashMap<>(); + + watches.put("train", trainMat); + watches.put("test", testMat); + + //set round + int round = 2; + + //train a boost model + Booster booster = XGBoost.train(paramMap, trainMat, round, watches, null, null); + + //predict raw output + float[][] predicts = booster.predict(testMat, true); + + //eval + IEvaluation eval = new EvalError(); + //error must be less than 0.1 + TestCase.assertTrue(eval.eval(predicts, testMat) < 0.1f); + + //test dump model + + } + + /** + * test cross validation + * + * @throws XGBoostError + */ + @Test + public void testCV() throws XGBoostError { + //load train mat + DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); + + //set params + Map<String, Object> param = new HashMap<String, Object>() { + { + put("eta", 1.0); + put("max_depth", 3); + put("silent", 1); + put("nthread", 6); + put("objective", "binary:logistic"); + put("gamma", 1.0); + put("eval_metric", "error"); + } + }; + + //do 5-fold cross validation + int round = 2; + int nfold = 5; + //set additional eval_metrics + String[] metrics = null; + + String[] evalHist = XGBoost.crossValiation(param, trainMat, round, nfold, metrics, + null, null); + } +} diff --git a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/DMatrixTest.java b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/DMatrixTest.java new file mode 100644 index 000000000..9b3a8b860 --- /dev/null +++ b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/DMatrixTest.java @@ -0,0 +1,103 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License.
+ */ +package ml.dmlc.xgboost4j; + +import junit.framework.TestCase; +import org.junit.Test; + +import java.util.Arrays; +import java.util.Random; + +/** + * test cases for DMatrix + * + * @author hzx + */ +public class DMatrixTest { + + @Test + public void testCreateFromFile() throws XGBoostError { + //create DMatrix from file + DMatrix dmat = new DMatrix("../../demo/data/agaricus.txt.test"); + //get label + float[] labels = dmat.getLabel(); + //check length + TestCase.assertTrue(dmat.rowNum() == labels.length); + //set weights + float[] weights = Arrays.copyOf(labels, labels.length); + dmat.setWeight(weights); + float[] dweights = dmat.getWeight(); + TestCase.assertTrue(Arrays.equals(weights, dweights)); + } + + @Test + public void testCreateFromCSR() throws XGBoostError { + //create Matrix from csr format sparse Matrix and labels + /** + * sparse matrix + * 1 0 2 3 0 + * 4 0 2 3 5 + * 3 1 2 5 0 + */ + float[] data = new float[]{1, 2, 3, 4, 2, 3, 5, 3, 1, 2, 5}; + int[] colIndex = new int[]{0, 2, 3, 0, 2, 3, 4, 0, 1, 2, 3}; + long[] rowHeaders = new long[]{0, 3, 7, 11}; + DMatrix dmat1 = new DMatrix(rowHeaders, colIndex, data, DMatrix.SparseType.CSR); + //check row num + System.out.println(dmat1.rowNum()); + TestCase.assertTrue(dmat1.rowNum() == 3); + //test set label + float[] label1 = new float[]{1, 0, 1}; + dmat1.setLabel(label1); + float[] label2 = dmat1.getLabel(); + TestCase.assertTrue(Arrays.equals(label1, label2)); + } + + @Test + public void testCreateFromDenseMatrix() throws XGBoostError { + //create DMatrix from 10*5 dense matrix + int nrow = 10; + int ncol = 5; + float[] data0 = new float[nrow * ncol]; + //put random nums + Random random = new Random(); + for (int i = 0; i < nrow * ncol; i++) { + data0[i] = random.nextFloat(); + } + + //create label + float[] label0 = new float[nrow]; + for (int i = 0; i < nrow; i++) { + label0[i] = random.nextFloat(); + } + + DMatrix dmat0 = new DMatrix(data0, nrow, ncol); + dmat0.setLabel(label0); + + //check + TestCase.assertTrue(dmat0.rowNum() == 10); + TestCase.assertTrue(dmat0.getLabel().length == 10); + + //set weights for each instance + float[] weights = new float[nrow]; + for (int i = 0; i < nrow; i++) { + weights[i] = random.nextFloat(); + } + dmat0.setWeight(weights); + + TestCase.assertTrue(Arrays.equals(weights, dmat0.getWeight())); + } +} diff --git a/tests/travis/run_test.sh b/tests/travis/run_test.sh index 5795d89ff..bf3a781e7 100755 --- a/tests/travis/run_test.sh +++ b/tests/travis/run_test.sh @@ -73,10 +73,9 @@ fi if [ ${TASK} == "java_test" ]; then set -e - make java - cd java - ./create_wrap.sh - cd xgboost4j + make jvm-packages + cd jvm-packages + ./create_jni.sh mvn clean install -DskipTests=true mvn test fi
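
Usage note (not part of the patch): a minimal sketch of how the new Scala wrapper in this diff could be exercised end to end. It assumes that the ml.dmlc.xgboost4j.scala.DMatrix class added elsewhere in this patch exposes a file-path constructor and that the returned Booster has a predict(DMatrix) method mirroring the Java API; the object name ScalaWalkThrough and the demo data paths (the same files used by the unit tests above) are illustrative only.

import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost}

object ScalaWalkThrough {
  def main(args: Array[String]): Unit = {
    // load the demo data shipped with the repository (same files as the unit tests)
    val trainMat = new DMatrix("../../demo/data/agaricus.txt.train")
    val testMat = new DMatrix("../../demo/data/agaricus.txt.test")

    // parameters are passed as Map[String, AnyRef], matching the signature of
    // XGBoost.train introduced in this patch
    val params: Map[String, AnyRef] = Map(
      "eta" -> "1.0",
      "max_depth" -> "2",
      "silent" -> "1",
      "objective" -> "binary:logistic")

    // evaluate on both matrices while training; no custom objective or eval metric
    val watches = Map("train" -> trainMat, "test" -> testMat)
    val booster = XGBoost.train(params, trainMat, 2, watches, null, null)

    // predict on the test set; predict(DMatrix) is assumed to mirror the Java Booster
    val predicts = booster.predict(testMat)
    println(s"predicted ${predicts.length} rows")
  }
}

Passing parameter values as strings sidesteps boxing of numeric literals into AnyRef; the native XGBoosterSetParam call receives every value as a string either way.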