From 55e36893cdd119016e7bb182c01e3aa37c6d7ae4 Mon Sep 17 00:00:00 2001 From: CodingCat Date: Tue, 1 Mar 2016 20:19:49 -0500 Subject: [PATCH] add style check for java and scala code --- jvm-packages/checkstyle-suppressions.xml | 33 ++ jvm-packages/checkstyle.xml | 169 ++++++++++ jvm-packages/pom.xml | 52 +++- jvm-packages/scalastyle-config.xml | 291 ++++++++++++++++++ .../dmlc/xgboost4j/demo/BasicWalkThrough.java | 169 +++++----- .../xgboost4j/demo/BoostFromPrediction.java | 67 ++-- .../dmlc/xgboost4j/demo/CrossValidation.java | 4 +- .../dmlc/xgboost4j/demo/CustomObjective.java | 257 ++++++++-------- .../dmlc/xgboost4j/demo/ExternalMemory.java | 64 ++-- .../demo/GeneralizedLinearModel.java | 77 ++--- .../xgboost4j/demo/PredictFirstNtree.java | 69 +++-- .../xgboost4j/demo/PredictLeafIndices.java | 69 +++-- .../dmlc/xgboost4j/demo/util/CustomEval.java | 60 ++-- .../dmlc/xgboost4j/demo/util/DataLoader.java | 187 +++++------ .../main/java/org/dmlc/xgboost4j/Booster.java | 11 +- .../main/java/org/dmlc/xgboost4j/DMatrix.java | 15 +- .../java/org/dmlc/xgboost4j/IEvaluation.java | 4 +- .../java/org/dmlc/xgboost4j/IObjective.java | 4 +- .../org/dmlc/xgboost4j/JNIErrorHandle.java | 6 +- .../org/dmlc/xgboost4j/JavaBoosterImpl.java | 21 +- .../org/dmlc/xgboost4j/NativeLibLoader.java | 18 +- .../main/java/org/dmlc/xgboost4j/XGBoost.java | 4 +- .../java/org/dmlc/xgboost4j/XGBoostError.java | 4 +- .../org/dmlc/xgboost4j/scala/Booster.scala | 24 +- .../org/dmlc/xgboost4j/scala/DMatrix.scala | 16 + .../org/dmlc/xgboost4j/scala/EvalTrait.scala | 38 +++ .../dmlc/xgboost4j/scala/ObjectiveTrait.scala | 30 ++ .../xgboost4j/scala/ScalaBoosterImpl.scala | 22 +- .../org/dmlc/xgboost4j/scala/XGBoost.scala | 49 ++- tests/travis/run_test.sh | 1 - 30 files changed, 1252 insertions(+), 583 deletions(-) create mode 100644 jvm-packages/checkstyle-suppressions.xml create mode 100644 jvm-packages/checkstyle.xml create mode 100644 jvm-packages/scalastyle-config.xml create mode 100644 jvm-packages/xgboost4j/src/main/scala/org/dmlc/xgboost4j/scala/EvalTrait.scala create mode 100644 jvm-packages/xgboost4j/src/main/scala/org/dmlc/xgboost4j/scala/ObjectiveTrait.scala diff --git a/jvm-packages/checkstyle-suppressions.xml b/jvm-packages/checkstyle-suppressions.xml new file mode 100644 index 000000000..85cc27e12 --- /dev/null +++ b/jvm-packages/checkstyle-suppressions.xml @@ -0,0 +1,33 @@ + + + + + + + + + diff --git a/jvm-packages/checkstyle.xml b/jvm-packages/checkstyle.xml new file mode 100644 index 000000000..9583ec282 --- /dev/null +++ b/jvm-packages/checkstyle.xml @@ -0,0 +1,169 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/jvm-packages/pom.xml b/jvm-packages/pom.xml index 33e634188..5ec221175 100644 --- a/jvm-packages/pom.xml +++ b/jvm-packages/pom.xml @@ -23,6 +23,47 @@ + + org.scalastyle + scalastyle-maven-plugin + 0.8.0 + + false + true + true + ${basedir}/src/main/scala + ${basedir}/src/test/scala + scalastyle-config.xml + UTF-8 + + + + checkstyle + validate + + check + + + + + + org.apache.maven.plugins + maven-checkstyle-plugin + 2.17 + + checkstyle.xml + true + + + + checkstyle + validate + + check + + + + net.alchim31.maven scala-maven-plugin @@ -53,6 +94,7 @@ org.apache.maven.plugins maven-surefire-plugin + 2.19.1 -Djava.library.path=lib/ @@ -65,16 +107,6 @@ commons-logging 1.2 - - org.scala-lang - scala-compiler - ${scala.version} - - - org.scala-lang - scala-library - ${scala.version} - org.scalatest scalatest_${scala.binary.version} diff --git a/jvm-packages/scalastyle-config.xml b/jvm-packages/scalastyle-config.xml new file mode 100644 index 000000000..27bb4fa8a --- /dev/null +++ b/jvm-packages/scalastyle-config.xml @@ -0,0 +1,291 @@ + + + + + Scalastyle standard configuration + + + + + + + + + + + + + + + + + + + + + + + + true + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + ARROW, EQUALS, ELSE, TRY, CATCH, FINALLY, LARROW, RARROW + + + + + + ARROW, EQUALS, COMMA, COLON, IF, ELSE, DO, WHILE, FOR, MATCH, TRY, CATCH, FINALLY, LARROW, RARROW + + + + + + + + + ^FunSuite[A-Za-z]*$ + Tests must extend org.apache.spark.SparkFunSuite instead. + + + + + ^println$ + + + + + @VisibleForTesting + + + + + Runtime\.getRuntime\.addShutdownHook + + + + + mutable\.SynchronizedBuffer + + + + + Class\.forName + + + + + + JavaConversions + Instead of importing implicits in scala.collection.JavaConversions._, import + scala.collection.JavaConverters._ and use .asScala / .asJava methods + + + + + java,scala,3rdParty,spark + javax?\..* + scala\..* + (?!org\.apache\.spark\.).* + org\.apache\.spark\..* + + + + + + COMMA + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 800> + + + + + 30 + + + + + 10 + + + + + 50 + + + + + + + + + + + -1,0,1,2,3 + + + diff --git a/jvm-packages/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/BasicWalkThrough.java b/jvm-packages/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/BasicWalkThrough.java index ab01c763a..92d4d7eed 100644 --- a/jvm-packages/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/BasicWalkThrough.java +++ b/jvm-packages/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/BasicWalkThrough.java @@ -1,10 +1,10 @@ /* - Copyright (c) 2014 by Contributors + Copyright (c) 2014 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - + http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software @@ -26,92 +26,93 @@ import java.util.HashMap; /** * a simple example of java wrapper for xgboost + * * @author hzx */ public class BasicWalkThrough { - public static boolean checkPredicts(float[][] fPredicts, float[][] sPredicts) { - if(fPredicts.length != sPredicts.length) { - return false; - } - - for(int i=0; i params = new HashMap(); - params.put("eta", 1.0); - params.put("max_depth", 2); - params.put("silent", 1); - params.put("objective", "binary:logistic"); - - - HashMap watches = new HashMap(); - watches.put("train", trainMat); - watches.put("test", testMat); - - //set round - int round = 2; - - //train a boost model - Booster booster = XGBoost.train(params, trainMat, round, watches, null, null); - - //predict - float[][] predicts = booster.predict(testMat); - - //save model to modelPath - File file = new File("./model"); - if(!file.exists()) { - file.mkdirs(); - } - - String modelPath = "./model/xgb.model"; - booster.saveModel(modelPath); - - //dump model - booster.dumpModel("./model/dump.raw.txt", false); - - //dump model with feature map - booster.dumpModel("./model/dump.nice.txt", "../../demo/data/featmap.txt", false); - - //save dmatrix into binary buffer - testMat.saveBinary("./model/dtest.buffer"); - - //reload model and data - Booster booster2 = XGBoost.loadBoostModel(params, "./model/xgb.model"); - DMatrix testMat2 = new DMatrix("./model/dtest.buffer"); - float[][] predicts2 = booster2.predict(testMat2); - - - //check the two predicts - System.out.println(checkPredicts(predicts, predicts2)); - - System.out.println("start build dmatrix from csr sparse data ..."); - //build dmatrix from CSR Sparse Matrix - DataLoader.CSRSparseData spData = DataLoader.loadSVMFile("../../demo/data/agaricus.txt.train"); - - DMatrix trainMat2 = new DMatrix(spData.rowHeaders, spData.colIndex, spData.data, - DMatrix.SparseType.CSR); - trainMat2.setLabel(spData.labels); - - //specify watchList - HashMap watches2 = new HashMap(); - watches2.put("train", trainMat2); - watches2.put("test", testMat2); - Booster booster3 = XGBoost.train(params, trainMat2, round, watches2, null, null); - float[][] predicts3 = booster3.predict(testMat2); - - //check predicts - System.out.println(checkPredicts(predicts, predicts3)); + for (int i = 0; i < fPredicts.length; i++) { + if (!Arrays.equals(fPredicts[i], sPredicts[i])) { + return false; + } } + + return true; + } + + + public static void main(String[] args) throws IOException, XGBoostError { + // load file from text file, also binary buffer generated by xgboost4j + DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); + DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test"); + + HashMap params = new HashMap(); + params.put("eta", 1.0); + params.put("max_depth", 2); + params.put("silent", 1); + params.put("objective", "binary:logistic"); + + + HashMap watches = new HashMap(); + watches.put("train", trainMat); + watches.put("test", testMat); + + //set round + int round = 2; + + //train a boost model + Booster booster = XGBoost.train(params, trainMat, round, watches, null, null); + + //predict + float[][] predicts = booster.predict(testMat); + + //save model to modelPath + File file = new File("./model"); + if (!file.exists()) { + file.mkdirs(); + } + + String modelPath = "./model/xgb.model"; + booster.saveModel(modelPath); + + //dump model + booster.dumpModel("./model/dump.raw.txt", false); + + //dump model with feature map + booster.dumpModel("./model/dump.nice.txt", "../../demo/data/featmap.txt", false); + + //save dmatrix into binary buffer + testMat.saveBinary("./model/dtest.buffer"); + + //reload model and data + Booster booster2 = XGBoost.loadBoostModel(params, "./model/xgb.model"); + DMatrix testMat2 = new DMatrix("./model/dtest.buffer"); + float[][] predicts2 = booster2.predict(testMat2); + + + //check the two predicts + System.out.println(checkPredicts(predicts, predicts2)); + + System.out.println("start build dmatrix from csr sparse data ..."); + //build dmatrix from CSR Sparse Matrix + DataLoader.CSRSparseData spData = DataLoader.loadSVMFile("../../demo/data/agaricus.txt.train"); + + DMatrix trainMat2 = new DMatrix(spData.rowHeaders, spData.colIndex, spData.data, + DMatrix.SparseType.CSR); + trainMat2.setLabel(spData.labels); + + //specify watchList + HashMap watches2 = new HashMap(); + watches2.put("train", trainMat2); + watches2.put("test", testMat2); + Booster booster3 = XGBoost.train(params, trainMat2, round, watches2, null, null); + float[][] predicts3 = booster3.predict(testMat2); + + //check predicts + System.out.println(checkPredicts(predicts, predicts3)); + } } diff --git a/jvm-packages/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/BoostFromPrediction.java b/jvm-packages/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/BoostFromPrediction.java index 9713ccc5b..cdfee20c5 100644 --- a/jvm-packages/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/BoostFromPrediction.java +++ b/jvm-packages/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/BoostFromPrediction.java @@ -1,10 +1,10 @@ /* - Copyright (c) 2014 by Contributors + Copyright (c) 2014 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - + http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software @@ -21,38 +21,39 @@ import java.util.HashMap; /** * example for start from a initial base prediction + * * @author hzx */ public class BoostFromPrediction { - public static void main(String[] args) throws XGBoostError { - System.out.println("start running example to start from a initial prediction"); - - // load file from text file, also binary buffer generated by xgboost4j - DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); - DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test"); - - //specify parameters - HashMap params = new HashMap(); - params.put("eta", 1.0); - params.put("max_depth", 2); - params.put("silent", 1); - params.put("objective", "binary:logistic"); - - //specify watchList - HashMap watches = new HashMap(); - watches.put("train", trainMat); - watches.put("test", testMat); - - //train xgboost for 1 round - Booster booster = XGBoost.train(params, trainMat, 1, watches, null, null); - - float[][] trainPred = booster.predict(trainMat, true); - float[][] testPred = booster.predict(testMat, true); - - trainMat.setBaseMargin(trainPred); - testMat.setBaseMargin(testPred); - - System.out.println("result of running from initial prediction"); - Booster booster2 = XGBoost.train(params, trainMat, 1, watches, null, null); - } + public static void main(String[] args) throws XGBoostError { + System.out.println("start running example to start from a initial prediction"); + + // load file from text file, also binary buffer generated by xgboost4j + DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); + DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test"); + + //specify parameters + HashMap params = new HashMap(); + params.put("eta", 1.0); + params.put("max_depth", 2); + params.put("silent", 1); + params.put("objective", "binary:logistic"); + + //specify watchList + HashMap watches = new HashMap(); + watches.put("train", trainMat); + watches.put("test", testMat); + + //train xgboost for 1 round + Booster booster = XGBoost.train(params, trainMat, 1, watches, null, null); + + float[][] trainPred = booster.predict(trainMat, true); + float[][] testPred = booster.predict(testMat, true); + + trainMat.setBaseMargin(trainPred); + testMat.setBaseMargin(testPred); + + System.out.println("result of running from initial prediction"); + Booster booster2 = XGBoost.train(params, trainMat, 1, watches, null, null); + } } diff --git a/jvm-packages/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/CrossValidation.java b/jvm-packages/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/CrossValidation.java index 962d3c355..9793644b7 100644 --- a/jvm-packages/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/CrossValidation.java +++ b/jvm-packages/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/CrossValidation.java @@ -1,10 +1,10 @@ /* - Copyright (c) 2014 by Contributors + Copyright (c) 2014 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - + http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software diff --git a/jvm-packages/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/CustomObjective.java b/jvm-packages/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/CustomObjective.java index 1ec592df0..9217f1d42 100644 --- a/jvm-packages/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/CustomObjective.java +++ b/jvm-packages/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/CustomObjective.java @@ -1,10 +1,10 @@ /* - Copyright (c) 2014 by Contributors + Copyright (c) 2014 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - + http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software @@ -29,137 +29,142 @@ import java.util.List; * this may make buildin evalution metric not function properly * for example, we are doing logistic loss, the prediction is score before logistic transformation * he buildin evaluation error assumes input is after logistic transformation - * Take this in mind when you use the customization, and maybe you need write customized evaluation function + * Take this in mind when you use the customization, and maybe you need write customized evaluation + * function + * * @author hzx */ public class CustomObjective { + /** + * loglikelihoode loss obj function + */ + public static class LogRegObj implements IObjective { + private static final Log logger = LogFactory.getLog(LogRegObj.class); + /** - * loglikelihoode loss obj function + * simple sigmoid func + * + * @param input + * @return Note: this func is not concern about numerical stability, only used as example */ - public static class LogRegObj implements IObjective { - private static final Log logger = LogFactory.getLog(LogRegObj.class); - - /** - * simple sigmoid func - * @param input - * @return - * Note: this func is not concern about numerical stability, only used as example - */ - public float sigmoid(float input) { - float val = (float) (1/(1+Math.exp(-input))); - return val; - } - - public float[][] transform(float[][] predicts) { - int nrow = predicts.length; - float[][] transPredicts = new float[nrow][1]; - - for(int i=0; i getGradient(float[][] predicts, org.dmlc.xgboost4j.DMatrix dtrain) { - int nrow = predicts.length; - List gradients = new ArrayList(); - float[] labels; - try { - labels = dtrain.getLabel(); - } catch (XGBoostError ex) { - logger.error(ex); - return null; - } - float[] grad = new float[nrow]; - float[] hess = new float[nrow]; - - float[][] transPredicts = transform(predicts); - - for(int i=0; i0) { - error++; - } - else if(labels[i]==1f && predicts[i][0]<=0) { - error++; - } - } - - return error/labels.length; - } - } - - public static void main(String[] args) throws XGBoostError { - //load train mat (svmlight format) - org.dmlc.xgboost4j.DMatrix trainMat = new org.dmlc.xgboost4j.DMatrix("../../demo/data/agaricus.txt.train"); - //load valid mat (svmlight format) - org.dmlc.xgboost4j.DMatrix testMat = new org.dmlc.xgboost4j.DMatrix("../../demo/data/agaricus.txt.test"); - - HashMap params = new HashMap(); - params.put("eta", 1.0); - params.put("max_depth", 2); - params.put("silent", 1); + public float[][] transform(float[][] predicts) { + int nrow = predicts.length; + float[][] transPredicts = new float[nrow][1]; - - //set round - int round = 2; - - //specify watchList - HashMap watches = new HashMap(); - watches.put("train", trainMat); - watches.put("test", testMat); - - //user define obj and eval - IObjective obj = new LogRegObj(); - IEvaluation eval = new EvalError(); - - //train a booster - System.out.println("begin to train the booster model"); - Booster booster = XGBoost.train(params, trainMat, round, watches, obj, eval); + for (int i = 0; i < nrow; i++) { + transPredicts[i][0] = sigmoid(predicts[i][0]); + } + + return transPredicts; } + + @Override + public List getGradient(float[][] predicts, org.dmlc.xgboost4j.DMatrix dtrain) { + int nrow = predicts.length; + List gradients = new ArrayList(); + float[] labels; + try { + labels = dtrain.getLabel(); + } catch (XGBoostError ex) { + logger.error(ex); + return null; + } + float[] grad = new float[nrow]; + float[] hess = new float[nrow]; + + float[][] transPredicts = transform(predicts); + + for (int i = 0; i < nrow; i++) { + float predict = transPredicts[i][0]; + grad[i] = predict - labels[i]; + hess[i] = predict * (1 - predict); + } + + gradients.add(grad); + gradients.add(hess); + return gradients; + } + } + + /** + * user defined eval function. + * NOTE: when you do customized loss function, the default prediction value is margin + * this may make buildin evalution metric not function properly + * for example, we are doing logistic loss, the prediction is score before logistic transformation + * the buildin evaluation error assumes input is after logistic transformation + * Take this in mind when you use the customization, and maybe you need write customized + * evaluation function + */ + public static class EvalError implements IEvaluation { + private static final Log logger = LogFactory.getLog(EvalError.class); + + String evalMetric = "custom_error"; + + public EvalError() { + } + + @Override + public String getMetric() { + return evalMetric; + } + + @Override + public float eval(float[][] predicts, org.dmlc.xgboost4j.DMatrix dmat) { + float error = 0f; + float[] labels; + try { + labels = dmat.getLabel(); + } catch (XGBoostError ex) { + logger.error(ex); + return -1f; + } + int nrow = predicts.length; + for (int i = 0; i < nrow; i++) { + if (labels[i] == 0f && predicts[i][0] > 0) { + error++; + } else if (labels[i] == 1f && predicts[i][0] <= 0) { + error++; + } + } + + return error / labels.length; + } + } + + public static void main(String[] args) throws XGBoostError { + //load train mat (svmlight format) + org.dmlc.xgboost4j.DMatrix trainMat = + new org.dmlc.xgboost4j.DMatrix("../../demo/data/agaricus.txt.train"); + //load valid mat (svmlight format) + org.dmlc.xgboost4j.DMatrix testMat = + new org.dmlc.xgboost4j.DMatrix("../../demo/data/agaricus.txt.test"); + + HashMap params = new HashMap(); + params.put("eta", 1.0); + params.put("max_depth", 2); + params.put("silent", 1); + + + //set round + int round = 2; + + //specify watchList + HashMap watches = + new HashMap(); + watches.put("train", trainMat); + watches.put("test", testMat); + + //user define obj and eval + IObjective obj = new LogRegObj(); + IEvaluation eval = new EvalError(); + + //train a booster + System.out.println("begin to train the booster model"); + Booster booster = XGBoost.train(params, trainMat, round, watches, obj, eval); + } } diff --git a/jvm-packages/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/ExternalMemory.java b/jvm-packages/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/ExternalMemory.java index bd0cec906..fe826a2e4 100644 --- a/jvm-packages/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/ExternalMemory.java +++ b/jvm-packages/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/ExternalMemory.java @@ -1,10 +1,10 @@ /* - Copyright (c) 2014 by Contributors + Copyright (c) 2014 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - + http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software @@ -21,36 +21,38 @@ import java.util.HashMap; /** * simple example for using external memory version + * * @author hzx */ public class ExternalMemory { - public static void main(String[] args) throws XGBoostError { - //this is the only difference, add a # followed by a cache prefix name - //several cache file with the prefix will be generated - //currently only support convert from libsvm file - DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train#dtrain.cache"); - DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test#dtest.cache"); - - //specify parameters - HashMap params = new HashMap(); - params.put("eta", 1.0); - params.put("max_depth", 2); - params.put("silent", 1); - params.put("objective", "binary:logistic"); - - //performance notice: set nthread to be the number of your real cpu - //some cpu offer two threads per core, for example, a 4 core cpu with 8 threads, in such case set nthread=4 - //param.put("nthread", num_real_cpu); - - //specify watchList - HashMap watches = new HashMap(); - watches.put("train", trainMat); - watches.put("test", testMat); - - //set round - int round = 2; - - //train a boost model - Booster booster = XGBoost.train(params, trainMat, round, watches, null, null); - } + public static void main(String[] args) throws XGBoostError { + //this is the only difference, add a # followed by a cache prefix name + //several cache file with the prefix will be generated + //currently only support convert from libsvm file + DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train#dtrain.cache"); + DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test#dtest.cache"); + + //specify parameters + HashMap params = new HashMap(); + params.put("eta", 1.0); + params.put("max_depth", 2); + params.put("silent", 1); + params.put("objective", "binary:logistic"); + + //performance notice: set nthread to be the number of your real cpu + //some cpu offer two threads per core, for example, a 4 core cpu with 8 threads, in such case + // set nthread=4 + //param.put("nthread", num_real_cpu); + + //specify watchList + HashMap watches = new HashMap(); + watches.put("train", trainMat); + watches.put("test", testMat); + + //set round + int round = 2; + + //train a boost model + Booster booster = XGBoost.train(params, trainMat, round, watches, null, null); + } } diff --git a/jvm-packages/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/GeneralizedLinearModel.java b/jvm-packages/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/GeneralizedLinearModel.java index b3800373d..5f795aa13 100644 --- a/jvm-packages/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/GeneralizedLinearModel.java +++ b/jvm-packages/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/GeneralizedLinearModel.java @@ -1,10 +1,10 @@ /* - Copyright (c) 2014 by Contributors + Copyright (c) 2014 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - + http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software @@ -23,44 +23,45 @@ import java.util.HashMap; /** * this is an example of fit generalized linear model in xgboost * basically, we are using linear model, instead of tree for our boosters + * * @author hzx */ public class GeneralizedLinearModel { - public static void main(String[] args) throws XGBoostError { - // load file from text file, also binary buffer generated by xgboost4j - DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); - DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test"); - - //specify parameters - //change booster to gblinear, so that we are fitting a linear model - // alpha is the L1 regularizer - //lambda is the L2 regularizer - //you can also set lambda_bias which is L2 regularizer on the bias term - HashMap params = new HashMap(); - params.put("alpha", 0.0001); - params.put("silent", 1); - params.put("objective", "binary:logistic"); - params.put("booster", "gblinear"); + public static void main(String[] args) throws XGBoostError { + // load file from text file, also binary buffer generated by xgboost4j + DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); + DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test"); - //normally, you do not need to set eta (step_size) - //XGBoost uses a parallel coordinate descent algorithm (shotgun), - //there could be affection on convergence with parallelization on certain cases - //setting eta to be smaller value, e.g 0.5 can make the optimization more stable - //param.put("eta", "0.5"); - - - //specify watchList - HashMap watches = new HashMap(); - watches.put("train", trainMat); - watches.put("test", testMat); - - //train a booster - int round = 4; - Booster booster = XGBoost.train(params, trainMat, round, watches, null, null); - - float[][] predicts = booster.predict(testMat); - - CustomEval eval = new CustomEval(); - System.out.println("error=" + eval.eval(predicts, testMat)); - } + //specify parameters + //change booster to gblinear, so that we are fitting a linear model + // alpha is the L1 regularizer + //lambda is the L2 regularizer + //you can also set lambda_bias which is L2 regularizer on the bias term + HashMap params = new HashMap(); + params.put("alpha", 0.0001); + params.put("silent", 1); + params.put("objective", "binary:logistic"); + params.put("booster", "gblinear"); + + //normally, you do not need to set eta (step_size) + //XGBoost uses a parallel coordinate descent algorithm (shotgun), + //there could be affection on convergence with parallelization on certain cases + //setting eta to be smaller value, e.g 0.5 can make the optimization more stable + //param.put("eta", "0.5"); + + + //specify watchList + HashMap watches = new HashMap(); + watches.put("train", trainMat); + watches.put("test", testMat); + + //train a booster + int round = 4; + Booster booster = XGBoost.train(params, trainMat, round, watches, null, null); + + float[][] predicts = booster.predict(testMat); + + CustomEval eval = new CustomEval(); + System.out.println("error=" + eval.eval(predicts, testMat)); + } } diff --git a/jvm-packages/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/PredictFirstNtree.java b/jvm-packages/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/PredictFirstNtree.java index 302eb330b..965707d2b 100644 --- a/jvm-packages/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/PredictFirstNtree.java +++ b/jvm-packages/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/PredictFirstNtree.java @@ -1,10 +1,10 @@ /* - Copyright (c) 2014 by Contributors + Copyright (c) 2014 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - + http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software @@ -22,41 +22,42 @@ import java.util.HashMap; /** * predict first ntree + * * @author hzx */ -public class PredictFirstNtree { - public static void main(String[] args) throws XGBoostError { - // load file from text file, also binary buffer generated by xgboost4j - DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); - DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test"); - - //specify parameters - HashMap params = new HashMap(); +public class PredictFirstNtree { + public static void main(String[] args) throws XGBoostError { + // load file from text file, also binary buffer generated by xgboost4j + DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); + DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test"); - params.put("eta", 1.0); - params.put("max_depth", 2); - params.put("silent", 1); - params.put("objective", "binary:logistic"); + //specify parameters + HashMap params = new HashMap(); - - //specify watchList - HashMap watches = new HashMap(); - watches.put("train", trainMat); - watches.put("test", testMat); + params.put("eta", 1.0); + params.put("max_depth", 2); + params.put("silent", 1); + params.put("objective", "binary:logistic"); - - //train a booster - int round = 3; - Booster booster = XGBoost.train(params, trainMat, round, watches, null, null); - - //predict use 1 tree - float[][] predicts1 = booster.predict(testMat, false, 1); - //by default all trees are used to do predict - float[][] predicts2 = booster.predict(testMat); - - //use a simple evaluation class to check error result - CustomEval eval = new CustomEval(); - System.out.println("error of predicts1: " + eval.eval(predicts1, testMat)); - System.out.println("error of predicts2: " + eval.eval(predicts2, testMat)); - } + + //specify watchList + HashMap watches = new HashMap(); + watches.put("train", trainMat); + watches.put("test", testMat); + + + //train a booster + int round = 3; + Booster booster = XGBoost.train(params, trainMat, round, watches, null, null); + + //predict use 1 tree + float[][] predicts1 = booster.predict(testMat, false, 1); + //by default all trees are used to do predict + float[][] predicts2 = booster.predict(testMat); + + //use a simple evaluation class to check error result + CustomEval eval = new CustomEval(); + System.out.println("error of predicts1: " + eval.eval(predicts1, testMat)); + System.out.println("error of predicts2: " + eval.eval(predicts2, testMat)); + } } diff --git a/jvm-packages/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/PredictLeafIndices.java b/jvm-packages/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/PredictLeafIndices.java index ab23526d3..552d3ceec 100644 --- a/jvm-packages/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/PredictLeafIndices.java +++ b/jvm-packages/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/PredictLeafIndices.java @@ -1,10 +1,10 @@ /* - Copyright (c) 2014 by Contributors + Copyright (c) 2014 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - + http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software @@ -22,41 +22,42 @@ import java.util.HashMap; /** * predict leaf indices + * * @author hzx */ public class PredictLeafIndices { - public static void main(String[] args) throws XGBoostError { - // load file from text file, also binary buffer generated by xgboost4j - DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); - DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test"); - - //specify parameters - HashMap params = new HashMap(); - params.put("eta", 1.0); - params.put("max_depth", 2); - params.put("silent", 1); - params.put("objective", "binary:logistic"); - - //specify watchList - HashMap watches = new HashMap(); - watches.put("train", trainMat); - watches.put("test", testMat); + public static void main(String[] args) throws XGBoostError { + // load file from text file, also binary buffer generated by xgboost4j + DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); + DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test"); - - //train a booster - int round = 3; - Booster booster = XGBoost.train(params, trainMat, round, watches, null, null); - - //predict using first 2 tree - float[][] leafindex = booster.predict(testMat, 2, true); - for(float[] leafs : leafindex) { - System.out.println(Arrays.toString(leafs)); - } - - //predict all trees - leafindex = booster.predict(testMat, 0, true); - for(float[] leafs : leafindex) { - System.out.println(Arrays.toString(leafs)); - } + //specify parameters + HashMap params = new HashMap(); + params.put("eta", 1.0); + params.put("max_depth", 2); + params.put("silent", 1); + params.put("objective", "binary:logistic"); + + //specify watchList + HashMap watches = new HashMap(); + watches.put("train", trainMat); + watches.put("test", testMat); + + + //train a booster + int round = 3; + Booster booster = XGBoost.train(params, trainMat, round, watches, null, null); + + //predict using first 2 tree + float[][] leafindex = booster.predict(testMat, 2, true); + for (float[] leafs : leafindex) { + System.out.println(Arrays.toString(leafs)); } + + //predict all trees + leafindex = booster.predict(testMat, 0, true); + for (float[] leafs : leafindex) { + System.out.println(Arrays.toString(leafs)); + } + } } diff --git a/jvm-packages/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/util/CustomEval.java b/jvm-packages/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/util/CustomEval.java index bf20bde4a..b80ac3ff7 100644 --- a/jvm-packages/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/util/CustomEval.java +++ b/jvm-packages/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/util/CustomEval.java @@ -1,10 +1,10 @@ /* - Copyright (c) 2014 by Contributors + Copyright (c) 2014 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - + http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software @@ -23,38 +23,38 @@ import org.dmlc.xgboost4j.XGBoostError; /** * a util evaluation class for examples + * * @author hzx */ public class CustomEval implements IEvaluation { - private static final Log logger = LogFactory.getLog(CustomEval.class); + private static final Log logger = LogFactory.getLog(CustomEval.class); - String evalMetric = "custom_error"; - - @Override - public String getMetric() { - return evalMetric; + String evalMetric = "custom_error"; + + @Override + public String getMetric() { + return evalMetric; + } + + @Override + public float eval(float[][] predicts, DMatrix dmat) { + float error = 0f; + float[] labels; + try { + labels = dmat.getLabel(); + } catch (XGBoostError ex) { + logger.error(ex); + return -1f; + } + int nrow = predicts.length; + for (int i = 0; i < nrow; i++) { + if (labels[i] == 0f && predicts[i][0] > 0.5) { + error++; + } else if (labels[i] == 1f && predicts[i][0] <= 0.5) { + error++; + } } - @Override - public float eval(float[][] predicts, DMatrix dmat) { - float error = 0f; - float[] labels; - try { - labels = dmat.getLabel(); - } catch (XGBoostError ex) { - logger.error(ex); - return -1f; - } - int nrow = predicts.length; - for(int i=0; i0.5) { - error++; - } - else if(labels[i]==1f && predicts[i][0]<=0.5) { - error++; - } - } - - return error/labels.length; - } + return error / labels.length; + } } diff --git a/jvm-packages/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/util/DataLoader.java b/jvm-packages/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/util/DataLoader.java index d146f4842..ca5f7def5 100644 --- a/jvm-packages/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/util/DataLoader.java +++ b/jvm-packages/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/util/DataLoader.java @@ -1,10 +1,10 @@ /* - Copyright (c) 2014 by Contributors + Copyright (c) 2014 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - + http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software @@ -23,100 +23,101 @@ import java.util.List; /** * util class for loading data + * * @author hzx */ public class DataLoader { - public static class DenseData { - public float[] labels; - public float[] data; - public int nrow; - public int ncol; + public static class DenseData { + public float[] labels; + public float[] data; + public int nrow; + public int ncol; + } + + public static class CSRSparseData { + public float[] labels; + public float[] data; + public long[] rowHeaders; + public int[] colIndex; + } + + public static DenseData loadCSVFile(String filePath) throws IOException { + DenseData denseData = new DenseData(); + + File f = new File(filePath); + FileInputStream in = new FileInputStream(f); + BufferedReader reader = new BufferedReader(new InputStreamReader(in, "UTF-8")); + + denseData.nrow = 0; + denseData.ncol = -1; + String line; + List tlabels = new ArrayList<>(); + List tdata = new ArrayList<>(); + + while ((line = reader.readLine()) != null) { + String[] items = line.trim().split(","); + if (items.length == 0) { + continue; + } + denseData.nrow++; + if (denseData.ncol == -1) { + denseData.ncol = items.length - 1; + } + + tlabels.add(Float.valueOf(items[items.length - 1])); + for (int i = 0; i < items.length - 1; i++) { + tdata.add(Float.valueOf(items[i])); + } } - - public static class CSRSparseData { - public float[] labels; - public float[] data; - public long[] rowHeaders; - public int[] colIndex; - } - - public static DenseData loadCSVFile(String filePath) throws FileNotFoundException, UnsupportedEncodingException, IOException { - DenseData denseData = new DenseData(); - - File f = new File(filePath); - FileInputStream in = new FileInputStream(f); - BufferedReader reader = new BufferedReader(new InputStreamReader(in, "UTF-8")); - - denseData.nrow = 0; - denseData.ncol = -1; - String line; - List tlabels = new ArrayList<>(); - List tdata = new ArrayList<>(); - - while((line=reader.readLine()) != null) { - String[] items = line.trim().split(","); - if(items.length==0) { - continue; - } - denseData.nrow++; - if(denseData.ncol == -1) { - denseData.ncol = items.length - 1; - } - - tlabels.add(Float.valueOf(items[items.length-1])); - for(int i=0; i tlabels = new ArrayList<>(); - List tdata = new ArrayList<>(); - List theaders = new ArrayList<>(); - List tindex = new ArrayList<>(); - - File f = new File(filePath); - FileInputStream in = new FileInputStream(f); - BufferedReader reader = new BufferedReader(new InputStreamReader(in, "UTF-8")); - - String line; - long rowheader = 0; - theaders.add(rowheader); - while((line=reader.readLine()) != null) { - String[] items = line.trim().split(" "); - if(items.length==0) { - continue; - } - - rowheader += items.length - 1; - theaders.add(rowheader); - tlabels.add(Float.valueOf(items[0])); - - for(int i=1; i tlabels = new ArrayList<>(); + List tdata = new ArrayList<>(); + List theaders = new ArrayList<>(); + List tindex = new ArrayList<>(); + + File f = new File(filePath); + FileInputStream in = new FileInputStream(f); + BufferedReader reader = new BufferedReader(new InputStreamReader(in, "UTF-8")); + + String line; + long rowheader = 0; + theaders.add(rowheader); + while ((line = reader.readLine()) != null) { + String[] items = line.trim().split(" "); + if (items.length == 0) { + continue; + } + + rowheader += items.length - 1; + theaders.add(rowheader); + tlabels.add(Float.valueOf(items[0])); + + for (int i = 1; i < items.length; i++) { + String[] tup = items[i].split(":"); + assert tup.length == 2; + + tdata.add(Float.valueOf(tup[1])); + tindex.add(Integer.valueOf(tup[0])); + } } + + spData.labels = ArrayUtils.toPrimitive(tlabels.toArray(new Float[tlabels.size()])); + spData.data = ArrayUtils.toPrimitive(tdata.toArray(new Float[tdata.size()])); + spData.colIndex = ArrayUtils.toPrimitive(tindex.toArray(new Integer[tindex.size()])); + spData.rowHeaders = ArrayUtils.toPrimitive(theaders.toArray(new Long[theaders.size()])); + + return spData; + } } diff --git a/jvm-packages/xgboost4j/src/main/java/org/dmlc/xgboost4j/Booster.java b/jvm-packages/xgboost4j/src/main/java/org/dmlc/xgboost4j/Booster.java index 2d6ac8475..626e15006 100644 --- a/jvm-packages/xgboost4j/src/main/java/org/dmlc/xgboost4j/Booster.java +++ b/jvm-packages/xgboost4j/src/main/java/org/dmlc/xgboost4j/Booster.java @@ -99,10 +99,10 @@ public interface Booster { * Predict with data * @param data dmatrix storing the input * @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees). - * @param predLeaf When this option is on, the output will be a matrix of (nsample, ntrees), nsample = data.numRow - with each record indicating the predicted leaf index of each sample in each tree. - Note that the leaf index of a tree is unique per tree, so you may find leaf 1 - in both tree 1 and tree 0. + * @param predLeaf When this option is on, the output will be a matrix of (nsample, ntrees), + * nsample = data.numRow with each record indicating the predicted leaf index of + * each sample in each tree. Note that the leaf index of a tree is unique per + * tree, so you may find leaf 1 in both tree 1 and tree 0. * @return predict result * @throws XGBoostError native error */ @@ -131,7 +131,8 @@ public interface Booster { * @param withStats bool * Controls whether the split statistics are output. */ - void dumpModel(String modelPath, String featureMap, boolean withStats) throws IOException, XGBoostError; + void dumpModel(String modelPath, String featureMap, boolean withStats) + throws IOException, XGBoostError; /** * get importance of each feature diff --git a/jvm-packages/xgboost4j/src/main/java/org/dmlc/xgboost4j/DMatrix.java b/jvm-packages/xgboost4j/src/main/java/org/dmlc/xgboost4j/DMatrix.java index ab0a5694f..53cf365d3 100644 --- a/jvm-packages/xgboost4j/src/main/java/org/dmlc/xgboost4j/DMatrix.java +++ b/jvm-packages/xgboost4j/src/main/java/org/dmlc/xgboost4j/DMatrix.java @@ -1,10 +1,10 @@ /* - Copyright (c) 2014 by Contributors + Copyright (c) 2014 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - + http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software @@ -32,7 +32,7 @@ public class DMatrix { //load native library static { try { - NativeLibLoader.InitXgboost(); + NativeLibLoader.initXgBoost(); } catch (IOException ex) { logger.error("load native library failed."); logger.error(ex); @@ -84,8 +84,6 @@ public class DMatrix { /** * used for DMatrix slice - * - * @param handle */ protected DMatrix(long handle) { this.handle = handle; @@ -216,8 +214,6 @@ public class DMatrix { /** * save DMatrix to filePath - * - * @param filePath file path */ public void saveBinary(String filePath) { XgboostJNI.XGDMatrixSaveBinary(handle, filePath, 1); @@ -225,8 +221,6 @@ public class DMatrix { /** * Get the handle - * - * @return native handler id */ public long getHandle() { return handle; @@ -234,9 +228,6 @@ public class DMatrix { /** * flatten a mat to array - * - * @param mat - * @return */ private static float[] flatten(float[][] mat) { int size = 0; diff --git a/jvm-packages/xgboost4j/src/main/java/org/dmlc/xgboost4j/IEvaluation.java b/jvm-packages/xgboost4j/src/main/java/org/dmlc/xgboost4j/IEvaluation.java index 777b8e8bb..1d1884463 100644 --- a/jvm-packages/xgboost4j/src/main/java/org/dmlc/xgboost4j/IEvaluation.java +++ b/jvm-packages/xgboost4j/src/main/java/org/dmlc/xgboost4j/IEvaluation.java @@ -1,10 +1,10 @@ /* - Copyright (c) 2014 by Contributors + Copyright (c) 2014 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - + http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software diff --git a/jvm-packages/xgboost4j/src/main/java/org/dmlc/xgboost4j/IObjective.java b/jvm-packages/xgboost4j/src/main/java/org/dmlc/xgboost4j/IObjective.java index 978ba8da5..87cd3009d 100644 --- a/jvm-packages/xgboost4j/src/main/java/org/dmlc/xgboost4j/IObjective.java +++ b/jvm-packages/xgboost4j/src/main/java/org/dmlc/xgboost4j/IObjective.java @@ -1,10 +1,10 @@ /* - Copyright (c) 2014 by Contributors + Copyright (c) 2014 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - + http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software diff --git a/jvm-packages/xgboost4j/src/main/java/org/dmlc/xgboost4j/JNIErrorHandle.java b/jvm-packages/xgboost4j/src/main/java/org/dmlc/xgboost4j/JNIErrorHandle.java index 4c0709ae5..d5feca00a 100644 --- a/jvm-packages/xgboost4j/src/main/java/org/dmlc/xgboost4j/JNIErrorHandle.java +++ b/jvm-packages/xgboost4j/src/main/java/org/dmlc/xgboost4j/JNIErrorHandle.java @@ -1,10 +1,10 @@ /* - Copyright (c) 2014 by Contributors + Copyright (c) 2014 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - + http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software @@ -30,7 +30,7 @@ class JNIErrorHandle { //load native library static { try { - NativeLibLoader.InitXgboost(); + NativeLibLoader.initXgBoost(); } catch (IOException ex) { logger.error("load native library failed."); logger.error(ex); diff --git a/jvm-packages/xgboost4j/src/main/java/org/dmlc/xgboost4j/JavaBoosterImpl.java b/jvm-packages/xgboost4j/src/main/java/org/dmlc/xgboost4j/JavaBoosterImpl.java index d7a38eff2..76027d65c 100644 --- a/jvm-packages/xgboost4j/src/main/java/org/dmlc/xgboost4j/JavaBoosterImpl.java +++ b/jvm-packages/xgboost4j/src/main/java/org/dmlc/xgboost4j/JavaBoosterImpl.java @@ -1,10 +1,10 @@ /* - Copyright (c) 2014 by Contributors + Copyright (c) 2014 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - + http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software @@ -38,7 +38,7 @@ class JavaBoosterImpl implements Booster { //load native library static { try { - NativeLibLoader.InitXgboost(); + NativeLibLoader.initXgBoost(); } catch (IOException ex) { logger.error("load native library failed."); logger.error(ex); @@ -80,7 +80,7 @@ class JavaBoosterImpl implements Booster { private void init(DMatrix[] dMatrixs) throws XGBoostError { long[] handles = null; if (dMatrixs != null) { - handles = dMatrixs2handles(dMatrixs); + handles = dmatrixsToHandles(dMatrixs); } long[] out = new long[1]; JNIErrorHandle.checkCall(XgboostJNI.XGBoosterCreate(handles, out)); @@ -151,7 +151,8 @@ class JavaBoosterImpl implements Booster { throw new AssertionError(String.format("grad/hess length mismatch %s / %s", grad.length, hess.length)); } - JNIErrorHandle.checkCall(XgboostJNI.XGBoosterBoostOneIter(handle, dtrain.getHandle(), grad, hess)); + JNIErrorHandle.checkCall(XgboostJNI.XGBoosterBoostOneIter(handle, dtrain.getHandle(), grad, + hess)); } /** @@ -164,9 +165,10 @@ class JavaBoosterImpl implements Booster { * @throws XGBoostError native error */ public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, int iter) throws XGBoostError { - long[] handles = dMatrixs2handles(evalMatrixs); + long[] handles = dmatrixsToHandles(evalMatrixs); String[] evalInfo = new String[1]; - JNIErrorHandle.checkCall(XgboostJNI.XGBoosterEvalOneIter(handle, iter, handles, evalNames, evalInfo)); + JNIErrorHandle.checkCall(XgboostJNI.XGBoosterEvalOneIter(handle, iter, handles, evalNames, + evalInfo)); return evalInfo[0]; } @@ -322,7 +324,8 @@ class JavaBoosterImpl implements Booster { statsFlag = 1; } String[][] modelInfos = new String[1][]; - JNIErrorHandle.checkCall(XgboostJNI.XGBoosterDumpModel(handle, featureMap, statsFlag, modelInfos)); + JNIErrorHandle.checkCall(XgboostJNI.XGBoosterDumpModel(handle, featureMap, statsFlag, + modelInfos)); return modelInfos[0]; } @@ -444,7 +447,7 @@ class JavaBoosterImpl implements Booster { * @param dmatrixs * @return handle array for input dmatrixs */ - private static long[] dMatrixs2handles(DMatrix[] dmatrixs) { + private static long[] dmatrixsToHandles(DMatrix[] dmatrixs) { long[] handles = new long[dmatrixs.length]; for (int i = 0; i < dmatrixs.length; i++) { handles[i] = dmatrixs[i].getHandle(); diff --git a/jvm-packages/xgboost4j/src/main/java/org/dmlc/xgboost4j/NativeLibLoader.java b/jvm-packages/xgboost4j/src/main/java/org/dmlc/xgboost4j/NativeLibLoader.java index b613e4153..dc8494e9c 100644 --- a/jvm-packages/xgboost4j/src/main/java/org/dmlc/xgboost4j/NativeLibLoader.java +++ b/jvm-packages/xgboost4j/src/main/java/org/dmlc/xgboost4j/NativeLibLoader.java @@ -1,10 +1,10 @@ /* - Copyright (c) 2014 by Contributors + Copyright (c) 2014 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - + http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software @@ -34,7 +34,7 @@ class NativeLibLoader { private static final String nativeResourcePath = "/lib/"; private static final String[] libNames = new String[]{"xgboost4j"}; - public static synchronized void InitXgboost() throws IOException { + public static synchronized void initXgBoost() throws IOException { if (!initialized) { for (String libName : libNames) { smartLoad(libName); @@ -50,14 +50,17 @@ class NativeLibLoader { * The temporary file is deleted after exiting. * Method uses String as filename because the pathname is "abstract", not system-dependent. *

- * The restrictions of {@link File#createTempFile(java.lang.String, java.lang.String)} apply to {@code path}. + * The restrictions of {@link File#createTempFile(java.lang.String, java.lang.String)} apply to + * {@code path}. * - * @param path The filename inside JAR as absolute path (beginning with '/'), e.g. /package/File.ext + * @param path The filename inside JAR as absolute path (beginning with '/'), + * e.g. /package/File.ext * @throws IOException If temporary file creation or read/write operation fails * @throws IllegalArgumentException If source file (param path) does not exist - * @throws IllegalArgumentException If the path is not absolute or if the filename is shorter than three characters + * @throws IllegalArgumentException If the path is not absolute or if the filename is shorter than + * three characters */ - private static void loadLibraryFromJar(String path) throws IOException { + private static void loadLibraryFromJar(String path) throws IOException, IllegalArgumentException{ if (!path.startsWith("/")) { throw new IllegalArgumentException("The path has to be absolute (start with '/')."); @@ -126,7 +129,6 @@ class NativeLibLoader { addNativeDir(nativePath); try { System.loadLibrary(libName); - System.out.println("======load " + libName + " successfully"); } catch (UnsatisfiedLinkError e) { try { String libraryFromJar = nativeResourcePath + System.mapLibraryName(libName); diff --git a/jvm-packages/xgboost4j/src/main/java/org/dmlc/xgboost4j/XGBoost.java b/jvm-packages/xgboost4j/src/main/java/org/dmlc/xgboost4j/XGBoost.java index 09acb24c9..ffea50dc7 100644 --- a/jvm-packages/xgboost4j/src/main/java/org/dmlc/xgboost4j/XGBoost.java +++ b/jvm-packages/xgboost4j/src/main/java/org/dmlc/xgboost4j/XGBoost.java @@ -1,10 +1,10 @@ /* - Copyright (c) 2014 by Contributors + Copyright (c) 2014 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - + http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software diff --git a/jvm-packages/xgboost4j/src/main/java/org/dmlc/xgboost4j/XGBoostError.java b/jvm-packages/xgboost4j/src/main/java/org/dmlc/xgboost4j/XGBoostError.java index 48281b250..69b69bc86 100644 --- a/jvm-packages/xgboost4j/src/main/java/org/dmlc/xgboost4j/XGBoostError.java +++ b/jvm-packages/xgboost4j/src/main/java/org/dmlc/xgboost4j/XGBoostError.java @@ -1,10 +1,10 @@ /* - Copyright (c) 2014 by Contributors + Copyright (c) 2014 by Contributors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - + http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software diff --git a/jvm-packages/xgboost4j/src/main/scala/org/dmlc/xgboost4j/scala/Booster.scala b/jvm-packages/xgboost4j/src/main/scala/org/dmlc/xgboost4j/scala/Booster.scala index 6b8c21c8f..7e0c574cf 100644 --- a/jvm-packages/xgboost4j/src/main/scala/org/dmlc/xgboost4j/scala/Booster.scala +++ b/jvm-packages/xgboost4j/src/main/scala/org/dmlc/xgboost4j/scala/Booster.scala @@ -1,3 +1,19 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + package org.dmlc.xgboost4j.scala import java.io.IOException @@ -111,10 +127,10 @@ trait Booster { * * @param data dmatrix storing the input * @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees). - * @param predLeaf When this option is on, the output will be a matrix of (nsample, ntrees), nsample = data.numRow - with each record indicating the predicted leaf index of each sample in each tree. - Note that the leaf index of a tree is unique per tree, so you may find leaf 1 - in both tree 1 and tree 0. + * @param predLeaf When this option is on, the output will be a matrix of (nsample, ntrees), + * nsample = data.numRow with each record indicating the predicted leaf index of + * each sample in each tree. Note that the leaf index of a tree is unique per + * tree, so you may find leaf 1 in both tree 1 and tree 0. * @return predict result * @throws XGBoostError native error */ diff --git a/jvm-packages/xgboost4j/src/main/scala/org/dmlc/xgboost4j/scala/DMatrix.scala b/jvm-packages/xgboost4j/src/main/scala/org/dmlc/xgboost4j/scala/DMatrix.scala index 553d0469f..84d0fa9e6 100644 --- a/jvm-packages/xgboost4j/src/main/scala/org/dmlc/xgboost4j/scala/DMatrix.scala +++ b/jvm-packages/xgboost4j/src/main/scala/org/dmlc/xgboost4j/scala/DMatrix.scala @@ -1,3 +1,19 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + package org.dmlc.xgboost4j.scala import org.dmlc.xgboost4j.{DMatrix => JDMatrix, XGBoostError} diff --git a/jvm-packages/xgboost4j/src/main/scala/org/dmlc/xgboost4j/scala/EvalTrait.scala b/jvm-packages/xgboost4j/src/main/scala/org/dmlc/xgboost4j/scala/EvalTrait.scala new file mode 100644 index 000000000..f62aeba2e --- /dev/null +++ b/jvm-packages/xgboost4j/src/main/scala/org/dmlc/xgboost4j/scala/EvalTrait.scala @@ -0,0 +1,38 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + +package org.dmlc.xgboost4j.scala + +import org.dmlc.xgboost4j.IEvaluation + +trait EvalTrait extends IEvaluation { + + /** + * get evaluate metric + * + * @return evalMetric + */ + def getMetric: String + + /** + * evaluate with predicts and data + * + * @param predicts predictions as array + * @param dmat data matrix to evaluate + * @return result of the metric + */ + def eval(predicts: Array[Array[Float]], dmat: DMatrix): Float +} diff --git a/jvm-packages/xgboost4j/src/main/scala/org/dmlc/xgboost4j/scala/ObjectiveTrait.scala b/jvm-packages/xgboost4j/src/main/scala/org/dmlc/xgboost4j/scala/ObjectiveTrait.scala new file mode 100644 index 000000000..1efc2c0ed --- /dev/null +++ b/jvm-packages/xgboost4j/src/main/scala/org/dmlc/xgboost4j/scala/ObjectiveTrait.scala @@ -0,0 +1,30 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + +package org.dmlc.xgboost4j.scala + +import org.dmlc.xgboost4j.IObjective + +trait ObjectiveTrait extends IObjective { + /** + * user define objective function, return gradient and second order gradient + * + * @param predicts untransformed margin predicts + * @param dtrain training data + * @return List with two float array, correspond to first order grad and second order grad + */ + def getGradient(predicts: Array[Array[Float]], dtrain: DMatrix): java.util.List[Array[Float]] +} diff --git a/jvm-packages/xgboost4j/src/main/scala/org/dmlc/xgboost4j/scala/ScalaBoosterImpl.scala b/jvm-packages/xgboost4j/src/main/scala/org/dmlc/xgboost4j/scala/ScalaBoosterImpl.scala index 4cdfb43e9..1f759fda9 100644 --- a/jvm-packages/xgboost4j/src/main/scala/org/dmlc/xgboost4j/scala/ScalaBoosterImpl.scala +++ b/jvm-packages/xgboost4j/src/main/scala/org/dmlc/xgboost4j/scala/ScalaBoosterImpl.scala @@ -1,3 +1,19 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + package org.dmlc.xgboost4j.scala import scala.collection.JavaConverters._ @@ -35,7 +51,8 @@ private[scala] class ScalaBoosterImpl private[xgboost4j](booster: JBooster) exte booster.evalSet(evalMatrixs.map(_.jDMatrix), evalNames, iter) } - override def evalSet(evalMatrixs: Array[DMatrix], evalNames: Array[String], eval: IEvaluation): String = { + override def evalSet(evalMatrixs: Array[DMatrix], evalNames: Array[String], eval: IEvaluation): + String = { booster.evalSet(evalMatrixs.map(_.jDMatrix), evalNames, eval) } @@ -51,7 +68,8 @@ private[scala] class ScalaBoosterImpl private[xgboost4j](booster: JBooster) exte booster.predict(data.jDMatrix, outPutMargin) } - override def predict(data: DMatrix, outPutMargin: Boolean, treeLimit: Int): Array[Array[Float]] = { + override def predict(data: DMatrix, outPutMargin: Boolean, treeLimit: Int): + Array[Array[Float]] = { booster.predict(data.jDMatrix, outPutMargin, treeLimit) } diff --git a/jvm-packages/xgboost4j/src/main/scala/org/dmlc/xgboost4j/scala/XGBoost.scala b/jvm-packages/xgboost4j/src/main/scala/org/dmlc/xgboost4j/scala/XGBoost.scala index 9b26bb988..df0510ce8 100644 --- a/jvm-packages/xgboost4j/src/main/scala/org/dmlc/xgboost4j/scala/XGBoost.scala +++ b/jvm-packages/xgboost4j/src/main/scala/org/dmlc/xgboost4j/scala/XGBoost.scala @@ -1,30 +1,47 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + package org.dmlc.xgboost4j.scala import _root_.scala.collection.JavaConverters._ - -import org.dmlc.xgboost4j -import org.dmlc.xgboost4j.{XGBoost => JXGBoost, IEvaluation, IObjective} +import org.dmlc.xgboost4j.{IEvaluation, IObjective, XGBoost => JXGBoost} object XGBoost { - def train(params: Map[String, AnyRef], dtrain: xgboost4j.DMatrix, round: Int, - watches: Map[String, xgboost4j.DMatrix], obj: IObjective, eval: IEvaluation): Booster = { - val xgboostInJava = JXGBoost.train(params.asJava, dtrain, round, watches.asJava, obj, eval) + def train(params: Map[String, AnyRef], dtrain: DMatrix, round: Int, + watches: Map[String, DMatrix], obj: IObjective, eval: IEvaluation): Booster = { + val jWatches = watches.map{case (name, matrix) => (name, matrix.jDMatrix)} + val xgboostInJava = JXGBoost.train(params.asJava, dtrain.jDMatrix, round, jWatches.asJava, + obj, eval) new ScalaBoosterImpl(xgboostInJava) } - def crossValiation(params: Map[String, AnyRef], - data: DMatrix, - round: Int, - nfold: Int, - metrics: Array[String], - obj: IObjective, - eval: IEvaluation): Array[String] = { - JXGBoost.crossValiation(params.asJava, data.jDMatrix, round, nfold, metrics, obj, - eval) + def crossValiation( + params: Map[String, AnyRef], + data: DMatrix, + round: Int, + nfold: Int, + metrics: Array[String], + obj: EvalTrait, + eval: ObjectiveTrait): Array[String] = { + JXGBoost.crossValiation(params.asJava, data.jDMatrix, round, nfold, metrics, + obj.asInstanceOf[IObjective], eval.asInstanceOf[IEvaluation]) } - def initBoostModel(params: Map[String, AnyRef], dMatrixs: Array[DMatrix]): Booster = { + def initBoostModel(params: Map[String, AnyRef], dMatrixs: Array[DMatrix]): Booster = { val xgboostInJava = JXGBoost.initBoostingModel(params.asJava, dMatrixs.map(_.jDMatrix)) new ScalaBoosterImpl(xgboostInJava) } diff --git a/tests/travis/run_test.sh b/tests/travis/run_test.sh index 8717e0182..6f6e276a6 100755 --- a/tests/travis/run_test.sh +++ b/tests/travis/run_test.sh @@ -76,7 +76,6 @@ if [ ${TASK} == "java_test" ]; then make jvm-packages cd jvm-packages ./create_wrap.sh - cd xgboost4j mvn clean install -DskipTests=true mvn test fi