diff --git a/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/Booster.java b/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/Booster.java index 51fee441e..c2058ceaa 100644 --- a/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/Booster.java +++ b/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/Booster.java @@ -76,14 +76,16 @@ public final class Booster { * @throws org.dmlc.xgboost4j.util.XGBoostError */ public Booster(Iterable> params, String modelPath) throws XGBoostError { - long[] out = new long[1]; init(null); + if(modelPath == null) { + throw new NullPointerException("modelPath : null"); + } loadModel(modelPath); setParam("seed","0"); setParams(params); } - + private void init(DMatrix[] dMatrixs) throws XGBoostError { diff --git a/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/DMatrix.java b/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/DMatrix.java index 61db98a6d..1a4f4dd28 100644 --- a/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/DMatrix.java +++ b/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/DMatrix.java @@ -55,6 +55,9 @@ public class DMatrix { * @throws org.dmlc.xgboost4j.util.XGBoostError */ public DMatrix(String dataPath) throws XGBoostError { + if(dataPath == null) { + throw new NullPointerException("dataPath: null"); + } long[] out = new long[1]; ErrorHandle.checkCall(XgboostJNI.XGDMatrixCreateFromFile(dataPath, 1, out)); handle = out[0]; diff --git a/java/xgboost4j/src/test/java/org/dmlc/xgboost4j/BoosterTest.java b/java/xgboost4j/src/test/java/org/dmlc/xgboost4j/BoosterTest.java new file mode 100644 index 000000000..eb022b7e8 --- /dev/null +++ b/java/xgboost4j/src/test/java/org/dmlc/xgboost4j/BoosterTest.java @@ -0,0 +1,108 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ +package org.dmlc.xgboost4j; + +import java.util.AbstractMap; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import junit.framework.TestCase; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.dmlc.xgboost4j.util.Trainer; +import org.dmlc.xgboost4j.util.XGBoostError; +import org.junit.Test; + +/** + * test cases for Booster + * @author hzx + */ +public class BoosterTest { + public static class EvalError implements IEvaluation { + private static final Log logger = LogFactory.getLog(EvalError.class); + + String evalMetric = "custom_error"; + + public EvalError() { + } + + @Override + public String getMetric() { + return evalMetric; + } + + @Override + public float eval(float[][] predicts, DMatrix dmat) { + float error = 0f; + float[] labels; + try { + labels = dmat.getLabel(); + } catch (XGBoostError ex) { + logger.error(ex); + return -1f; + } + int nrow = predicts.length; + for(int i=0; i0) { + error++; + } + else if(labels[i]==1f && predicts[i][0]<=0) { + error++; + } + } + + return error/labels.length; + } + } + + @Test + public void testBoosterBasic() throws XGBoostError { + DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); + DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test"); + + //set params + Map paramMap = new HashMap() { + { + put("eta", 1.0); + put("max_depth", 2); + put("silent", 1); + put("objective", "binary:logistic"); + } + }; + Iterable> param = paramMap.entrySet(); + + //set watchList + List> watchs = new ArrayList<>(); + watchs.add(new AbstractMap.SimpleEntry<>("train", trainMat)); + watchs.add(new AbstractMap.SimpleEntry<>("test", testMat)); + + //set round + int round = 2; + + //train a boost model + Booster booster = Trainer.train(param, trainMat, round, watchs, null, null); + + //predict raw output + float[][] predicts = booster.predict(testMat, true); + + //eval + IEvaluation eval = new EvalError(); + //error must be less than 0.1 + TestCase.assertTrue(eval.eval(predicts, testMat)<0.1f); + } +} diff --git a/java/xgboost4j/src/test/java/org/dmlc/xgboost4j/DMatrixTest.java b/java/xgboost4j/src/test/java/org/dmlc/xgboost4j/DMatrixTest.java new file mode 100644 index 000000000..343dd3ed9 --- /dev/null +++ b/java/xgboost4j/src/test/java/org/dmlc/xgboost4j/DMatrixTest.java @@ -0,0 +1,102 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ +package org.dmlc.xgboost4j; + +import java.util.Arrays; +import java.util.Random; +import junit.framework.TestCase; +import org.dmlc.xgboost4j.util.XGBoostError; +import org.junit.Test; + +/** + * test cases for DMatrix + * @author hzx + */ +public class DMatrixTest { + + @Test + public void testCreateFromFile() throws XGBoostError { + //create DMatrix from file + DMatrix dmat = new DMatrix("../../demo/data/agaricus.txt.test"); + //get label + float[] labels = dmat.getLabel(); + //check length + TestCase.assertTrue(dmat.rowNum()==labels.length); + //set weights + float[] weights = Arrays.copyOf(labels, labels.length); + dmat.setWeight(weights); + float[] dweights = dmat.getWeight(); + TestCase.assertTrue(Arrays.equals(weights, dweights)); + } + + @Test + public void testCreateFromCSR() throws XGBoostError { + //create Matrix from csr format sparse Matrix and labels + /** + * sparse matrix + * 1 0 2 3 0 + * 4 0 2 3 5 + * 3 1 2 5 0 + */ + float[] data = new float[] {1, 2, 3, 4, 2, 3, 5, 3, 1, 2, 5}; + int[] colIndex = new int[] {0, 2, 3, 0, 2, 3, 4, 0, 1, 2, 3}; + long[] rowHeaders = new long[] {0, 3, 7, 11}; + DMatrix dmat1 = new DMatrix(rowHeaders, colIndex, data, DMatrix.SparseType.CSR); + //check row num + System.out.println(dmat1.rowNum()); + TestCase.assertTrue(dmat1.rowNum()==3); + //test set label + float[] label1 = new float[] {1, 0, 1}; + dmat1.setLabel(label1); + float[] label2 = dmat1.getLabel(); + TestCase.assertTrue(Arrays.equals(label1, label2)); + } + + @Test + public void testCreateFromDenseMatrix() throws XGBoostError { + //create DMatrix from 10*5 dense matrix + int nrow = 10; + int ncol = 5; + float[] data0 = new float[nrow*ncol]; + //put random nums + Random random = new Random(); + for(int i=0; i