add testcases
This commit is contained in:
commit
0fc47f5abb
@ -76,14 +76,16 @@ public final class Booster {
|
||||
* @throws org.dmlc.xgboost4j.util.XGBoostError
|
||||
*/
|
||||
public Booster(Iterable<Entry<String, Object>> params, String modelPath) throws XGBoostError {
|
||||
long[] out = new long[1];
|
||||
init(null);
|
||||
if(modelPath == null) {
|
||||
throw new NullPointerException("modelPath : null");
|
||||
}
|
||||
loadModel(modelPath);
|
||||
setParam("seed","0");
|
||||
setParams(params);
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
private void init(DMatrix[] dMatrixs) throws XGBoostError {
|
||||
|
||||
@ -55,6 +55,9 @@ public class DMatrix {
|
||||
* @throws org.dmlc.xgboost4j.util.XGBoostError
|
||||
*/
|
||||
public DMatrix(String dataPath) throws XGBoostError {
|
||||
if(dataPath == null) {
|
||||
throw new NullPointerException("dataPath: null");
|
||||
}
|
||||
long[] out = new long[1];
|
||||
ErrorHandle.checkCall(XgboostJNI.XGDMatrixCreateFromFile(dataPath, 1, out));
|
||||
handle = out[0];
|
||||
|
||||
108
java/xgboost4j/src/test/java/org/dmlc/xgboost4j/BoosterTest.java
Normal file
108
java/xgboost4j/src/test/java/org/dmlc/xgboost4j/BoosterTest.java
Normal file
@ -0,0 +1,108 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package org.dmlc.xgboost4j;
|
||||
|
||||
import java.util.AbstractMap;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Map.Entry;
|
||||
import junit.framework.TestCase;
|
||||
import org.apache.commons.logging.Log;
|
||||
import org.apache.commons.logging.LogFactory;
|
||||
import org.dmlc.xgboost4j.util.Trainer;
|
||||
import org.dmlc.xgboost4j.util.XGBoostError;
|
||||
import org.junit.Test;
|
||||
|
||||
/**
|
||||
* test cases for Booster
|
||||
* @author hzx
|
||||
*/
|
||||
public class BoosterTest {
|
||||
public static class EvalError implements IEvaluation {
|
||||
private static final Log logger = LogFactory.getLog(EvalError.class);
|
||||
|
||||
String evalMetric = "custom_error";
|
||||
|
||||
public EvalError() {
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getMetric() {
|
||||
return evalMetric;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float eval(float[][] predicts, DMatrix dmat) {
|
||||
float error = 0f;
|
||||
float[] labels;
|
||||
try {
|
||||
labels = dmat.getLabel();
|
||||
} catch (XGBoostError ex) {
|
||||
logger.error(ex);
|
||||
return -1f;
|
||||
}
|
||||
int nrow = predicts.length;
|
||||
for(int i=0; i<nrow; i++) {
|
||||
if(labels[i]==0f && predicts[i][0]>0) {
|
||||
error++;
|
||||
}
|
||||
else if(labels[i]==1f && predicts[i][0]<=0) {
|
||||
error++;
|
||||
}
|
||||
}
|
||||
|
||||
return error/labels.length;
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testBoosterBasic() throws XGBoostError {
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||
|
||||
//set params
|
||||
Map<String, Object> paramMap = new HashMap<String, Object>() {
|
||||
{
|
||||
put("eta", 1.0);
|
||||
put("max_depth", 2);
|
||||
put("silent", 1);
|
||||
put("objective", "binary:logistic");
|
||||
}
|
||||
};
|
||||
Iterable<Entry<String, Object>> param = paramMap.entrySet();
|
||||
|
||||
//set watchList
|
||||
List<Entry<String, DMatrix>> watchs = new ArrayList<>();
|
||||
watchs.add(new AbstractMap.SimpleEntry<>("train", trainMat));
|
||||
watchs.add(new AbstractMap.SimpleEntry<>("test", testMat));
|
||||
|
||||
//set round
|
||||
int round = 2;
|
||||
|
||||
//train a boost model
|
||||
Booster booster = Trainer.train(param, trainMat, round, watchs, null, null);
|
||||
|
||||
//predict raw output
|
||||
float[][] predicts = booster.predict(testMat, true);
|
||||
|
||||
//eval
|
||||
IEvaluation eval = new EvalError();
|
||||
//error must be less than 0.1
|
||||
TestCase.assertTrue(eval.eval(predicts, testMat)<0.1f);
|
||||
}
|
||||
}
|
||||
102
java/xgboost4j/src/test/java/org/dmlc/xgboost4j/DMatrixTest.java
Normal file
102
java/xgboost4j/src/test/java/org/dmlc/xgboost4j/DMatrixTest.java
Normal file
@ -0,0 +1,102 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package org.dmlc.xgboost4j;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.Random;
|
||||
import junit.framework.TestCase;
|
||||
import org.dmlc.xgboost4j.util.XGBoostError;
|
||||
import org.junit.Test;
|
||||
|
||||
/**
|
||||
* test cases for DMatrix
|
||||
* @author hzx
|
||||
*/
|
||||
public class DMatrixTest {
|
||||
|
||||
@Test
|
||||
public void testCreateFromFile() throws XGBoostError {
|
||||
//create DMatrix from file
|
||||
DMatrix dmat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||
//get label
|
||||
float[] labels = dmat.getLabel();
|
||||
//check length
|
||||
TestCase.assertTrue(dmat.rowNum()==labels.length);
|
||||
//set weights
|
||||
float[] weights = Arrays.copyOf(labels, labels.length);
|
||||
dmat.setWeight(weights);
|
||||
float[] dweights = dmat.getWeight();
|
||||
TestCase.assertTrue(Arrays.equals(weights, dweights));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCreateFromCSR() throws XGBoostError {
|
||||
//create Matrix from csr format sparse Matrix and labels
|
||||
/**
|
||||
* sparse matrix
|
||||
* 1 0 2 3 0
|
||||
* 4 0 2 3 5
|
||||
* 3 1 2 5 0
|
||||
*/
|
||||
float[] data = new float[] {1, 2, 3, 4, 2, 3, 5, 3, 1, 2, 5};
|
||||
int[] colIndex = new int[] {0, 2, 3, 0, 2, 3, 4, 0, 1, 2, 3};
|
||||
long[] rowHeaders = new long[] {0, 3, 7, 11};
|
||||
DMatrix dmat1 = new DMatrix(rowHeaders, colIndex, data, DMatrix.SparseType.CSR);
|
||||
//check row num
|
||||
System.out.println(dmat1.rowNum());
|
||||
TestCase.assertTrue(dmat1.rowNum()==3);
|
||||
//test set label
|
||||
float[] label1 = new float[] {1, 0, 1};
|
||||
dmat1.setLabel(label1);
|
||||
float[] label2 = dmat1.getLabel();
|
||||
TestCase.assertTrue(Arrays.equals(label1, label2));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCreateFromDenseMatrix() throws XGBoostError {
|
||||
//create DMatrix from 10*5 dense matrix
|
||||
int nrow = 10;
|
||||
int ncol = 5;
|
||||
float[] data0 = new float[nrow*ncol];
|
||||
//put random nums
|
||||
Random random = new Random();
|
||||
for(int i=0; i<nrow*ncol; i++) {
|
||||
data0[i] = random.nextFloat();
|
||||
}
|
||||
|
||||
//create label
|
||||
float[] label0 = new float[nrow];
|
||||
for(int i=0; i<nrow; i++) {
|
||||
label0[i] = random.nextFloat();
|
||||
}
|
||||
|
||||
DMatrix dmat0 = new DMatrix(data0, nrow, ncol);
|
||||
dmat0.setLabel(label0);
|
||||
|
||||
//check
|
||||
TestCase.assertTrue(dmat0.rowNum()==10);
|
||||
TestCase.assertTrue(dmat0.getLabel().length==10);
|
||||
|
||||
//set weights for each instance
|
||||
float[] weights = new float[nrow];
|
||||
for(int i=0; i<nrow; i++) {
|
||||
weights[i] = random.nextFloat();
|
||||
}
|
||||
dmat0.setWeight(weights);
|
||||
|
||||
TestCase.assertTrue(Arrays.equals(weights, dmat0.getWeight()));
|
||||
}
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user