add testcases
This commit is contained in:
commit
0fc47f5abb
@ -76,8 +76,10 @@ public final class Booster {
|
|||||||
* @throws org.dmlc.xgboost4j.util.XGBoostError
|
* @throws org.dmlc.xgboost4j.util.XGBoostError
|
||||||
*/
|
*/
|
||||||
public Booster(Iterable<Entry<String, Object>> params, String modelPath) throws XGBoostError {
|
public Booster(Iterable<Entry<String, Object>> params, String modelPath) throws XGBoostError {
|
||||||
long[] out = new long[1];
|
|
||||||
init(null);
|
init(null);
|
||||||
|
if(modelPath == null) {
|
||||||
|
throw new NullPointerException("modelPath : null");
|
||||||
|
}
|
||||||
loadModel(modelPath);
|
loadModel(modelPath);
|
||||||
setParam("seed","0");
|
setParam("seed","0");
|
||||||
setParams(params);
|
setParams(params);
|
||||||
|
|||||||
@ -55,6 +55,9 @@ public class DMatrix {
|
|||||||
* @throws org.dmlc.xgboost4j.util.XGBoostError
|
* @throws org.dmlc.xgboost4j.util.XGBoostError
|
||||||
*/
|
*/
|
||||||
public DMatrix(String dataPath) throws XGBoostError {
|
public DMatrix(String dataPath) throws XGBoostError {
|
||||||
|
if(dataPath == null) {
|
||||||
|
throw new NullPointerException("dataPath: null");
|
||||||
|
}
|
||||||
long[] out = new long[1];
|
long[] out = new long[1];
|
||||||
ErrorHandle.checkCall(XgboostJNI.XGDMatrixCreateFromFile(dataPath, 1, out));
|
ErrorHandle.checkCall(XgboostJNI.XGDMatrixCreateFromFile(dataPath, 1, out));
|
||||||
handle = out[0];
|
handle = out[0];
|
||||||
|
|||||||
108
java/xgboost4j/src/test/java/org/dmlc/xgboost4j/BoosterTest.java
Normal file
108
java/xgboost4j/src/test/java/org/dmlc/xgboost4j/BoosterTest.java
Normal file
@ -0,0 +1,108 @@
|
|||||||
|
/*
|
||||||
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
package org.dmlc.xgboost4j;
|
||||||
|
|
||||||
|
import java.util.AbstractMap;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Map.Entry;
|
||||||
|
import junit.framework.TestCase;
|
||||||
|
import org.apache.commons.logging.Log;
|
||||||
|
import org.apache.commons.logging.LogFactory;
|
||||||
|
import org.dmlc.xgboost4j.util.Trainer;
|
||||||
|
import org.dmlc.xgboost4j.util.XGBoostError;
|
||||||
|
import org.junit.Test;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* test cases for Booster
|
||||||
|
* @author hzx
|
||||||
|
*/
|
||||||
|
public class BoosterTest {
|
||||||
|
public static class EvalError implements IEvaluation {
|
||||||
|
private static final Log logger = LogFactory.getLog(EvalError.class);
|
||||||
|
|
||||||
|
String evalMetric = "custom_error";
|
||||||
|
|
||||||
|
public EvalError() {
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String getMetric() {
|
||||||
|
return evalMetric;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public float eval(float[][] predicts, DMatrix dmat) {
|
||||||
|
float error = 0f;
|
||||||
|
float[] labels;
|
||||||
|
try {
|
||||||
|
labels = dmat.getLabel();
|
||||||
|
} catch (XGBoostError ex) {
|
||||||
|
logger.error(ex);
|
||||||
|
return -1f;
|
||||||
|
}
|
||||||
|
int nrow = predicts.length;
|
||||||
|
for(int i=0; i<nrow; i++) {
|
||||||
|
if(labels[i]==0f && predicts[i][0]>0) {
|
||||||
|
error++;
|
||||||
|
}
|
||||||
|
else if(labels[i]==1f && predicts[i][0]<=0) {
|
||||||
|
error++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return error/labels.length;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testBoosterBasic() throws XGBoostError {
|
||||||
|
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||||
|
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||||
|
|
||||||
|
//set params
|
||||||
|
Map<String, Object> paramMap = new HashMap<String, Object>() {
|
||||||
|
{
|
||||||
|
put("eta", 1.0);
|
||||||
|
put("max_depth", 2);
|
||||||
|
put("silent", 1);
|
||||||
|
put("objective", "binary:logistic");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
Iterable<Entry<String, Object>> param = paramMap.entrySet();
|
||||||
|
|
||||||
|
//set watchList
|
||||||
|
List<Entry<String, DMatrix>> watchs = new ArrayList<>();
|
||||||
|
watchs.add(new AbstractMap.SimpleEntry<>("train", trainMat));
|
||||||
|
watchs.add(new AbstractMap.SimpleEntry<>("test", testMat));
|
||||||
|
|
||||||
|
//set round
|
||||||
|
int round = 2;
|
||||||
|
|
||||||
|
//train a boost model
|
||||||
|
Booster booster = Trainer.train(param, trainMat, round, watchs, null, null);
|
||||||
|
|
||||||
|
//predict raw output
|
||||||
|
float[][] predicts = booster.predict(testMat, true);
|
||||||
|
|
||||||
|
//eval
|
||||||
|
IEvaluation eval = new EvalError();
|
||||||
|
//error must be less than 0.1
|
||||||
|
TestCase.assertTrue(eval.eval(predicts, testMat)<0.1f);
|
||||||
|
}
|
||||||
|
}
|
||||||
102
java/xgboost4j/src/test/java/org/dmlc/xgboost4j/DMatrixTest.java
Normal file
102
java/xgboost4j/src/test/java/org/dmlc/xgboost4j/DMatrixTest.java
Normal file
@ -0,0 +1,102 @@
|
|||||||
|
/*
|
||||||
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
package org.dmlc.xgboost4j;
|
||||||
|
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.Random;
|
||||||
|
import junit.framework.TestCase;
|
||||||
|
import org.dmlc.xgboost4j.util.XGBoostError;
|
||||||
|
import org.junit.Test;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* test cases for DMatrix
|
||||||
|
* @author hzx
|
||||||
|
*/
|
||||||
|
public class DMatrixTest {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testCreateFromFile() throws XGBoostError {
|
||||||
|
//create DMatrix from file
|
||||||
|
DMatrix dmat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||||
|
//get label
|
||||||
|
float[] labels = dmat.getLabel();
|
||||||
|
//check length
|
||||||
|
TestCase.assertTrue(dmat.rowNum()==labels.length);
|
||||||
|
//set weights
|
||||||
|
float[] weights = Arrays.copyOf(labels, labels.length);
|
||||||
|
dmat.setWeight(weights);
|
||||||
|
float[] dweights = dmat.getWeight();
|
||||||
|
TestCase.assertTrue(Arrays.equals(weights, dweights));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testCreateFromCSR() throws XGBoostError {
|
||||||
|
//create Matrix from csr format sparse Matrix and labels
|
||||||
|
/**
|
||||||
|
* sparse matrix
|
||||||
|
* 1 0 2 3 0
|
||||||
|
* 4 0 2 3 5
|
||||||
|
* 3 1 2 5 0
|
||||||
|
*/
|
||||||
|
float[] data = new float[] {1, 2, 3, 4, 2, 3, 5, 3, 1, 2, 5};
|
||||||
|
int[] colIndex = new int[] {0, 2, 3, 0, 2, 3, 4, 0, 1, 2, 3};
|
||||||
|
long[] rowHeaders = new long[] {0, 3, 7, 11};
|
||||||
|
DMatrix dmat1 = new DMatrix(rowHeaders, colIndex, data, DMatrix.SparseType.CSR);
|
||||||
|
//check row num
|
||||||
|
System.out.println(dmat1.rowNum());
|
||||||
|
TestCase.assertTrue(dmat1.rowNum()==3);
|
||||||
|
//test set label
|
||||||
|
float[] label1 = new float[] {1, 0, 1};
|
||||||
|
dmat1.setLabel(label1);
|
||||||
|
float[] label2 = dmat1.getLabel();
|
||||||
|
TestCase.assertTrue(Arrays.equals(label1, label2));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testCreateFromDenseMatrix() throws XGBoostError {
|
||||||
|
//create DMatrix from 10*5 dense matrix
|
||||||
|
int nrow = 10;
|
||||||
|
int ncol = 5;
|
||||||
|
float[] data0 = new float[nrow*ncol];
|
||||||
|
//put random nums
|
||||||
|
Random random = new Random();
|
||||||
|
for(int i=0; i<nrow*ncol; i++) {
|
||||||
|
data0[i] = random.nextFloat();
|
||||||
|
}
|
||||||
|
|
||||||
|
//create label
|
||||||
|
float[] label0 = new float[nrow];
|
||||||
|
for(int i=0; i<nrow; i++) {
|
||||||
|
label0[i] = random.nextFloat();
|
||||||
|
}
|
||||||
|
|
||||||
|
DMatrix dmat0 = new DMatrix(data0, nrow, ncol);
|
||||||
|
dmat0.setLabel(label0);
|
||||||
|
|
||||||
|
//check
|
||||||
|
TestCase.assertTrue(dmat0.rowNum()==10);
|
||||||
|
TestCase.assertTrue(dmat0.getLabel().length==10);
|
||||||
|
|
||||||
|
//set weights for each instance
|
||||||
|
float[] weights = new float[nrow];
|
||||||
|
for(int i=0; i<nrow; i++) {
|
||||||
|
weights[i] = random.nextFloat();
|
||||||
|
}
|
||||||
|
dmat0.setWeight(weights);
|
||||||
|
|
||||||
|
TestCase.assertTrue(Arrays.equals(weights, dmat0.getWeight()));
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
x
Reference in New Issue
Block a user