re-structure Java API, add Scala API and consolidate the names of Java/Scala API

This commit is contained in:
CodingCat
2015-12-22 03:29:20 -06:00
parent fc4c88fceb
commit 3b246c2420
56 changed files with 2902 additions and 2384 deletions

View File

@@ -0,0 +1,139 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j;
import junit.framework.TestCase;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.junit.Test;
import java.util.*;
import java.util.Map.Entry;
/**
* test cases for Booster
*
* @author hzx
*/
public class BoosterImplTest {
public static class EvalError implements IEvaluation {
private static final Log logger = LogFactory.getLog(EvalError.class);
String evalMetric = "custom_error";
public EvalError() {
}
@Override
public String getMetric() {
return evalMetric;
}
@Override
public float eval(float[][] predicts, org.dmlc.xgboost4j.DMatrix dmat) {
float error = 0f;
float[] labels;
try {
labels = dmat.getLabel();
} catch (XGBoostError ex) {
logger.error(ex);
return -1f;
}
int nrow = predicts.length;
for (int i = 0; i < nrow; i++) {
if (labels[i] == 0f && predicts[i][0] > 0) {
error++;
} else if (labels[i] == 1f && predicts[i][0] <= 0) {
error++;
}
}
return error / labels.length;
}
}
@Test
public void testBoosterBasic() throws XGBoostError {
org.dmlc.xgboost4j.DMatrix trainMat = new org.dmlc.xgboost4j.DMatrix("../../demo/data/agaricus.txt.train");
org.dmlc.xgboost4j.DMatrix testMat = new org.dmlc.xgboost4j.DMatrix("../../demo/data/agaricus.txt.test");
//set params
Map<String, Object> paramMap = new HashMap<String, Object>() {
{
put("eta", 1.0);
put("max_depth", 2);
put("silent", 1);
put("objective", "binary:logistic");
}
};
//set watchList
HashMap<String, org.dmlc.xgboost4j.DMatrix> watches = new HashMap<>();
watches.put("train", trainMat);
watches.put("test", testMat);
//set round
int round = 2;
//train a boost model
Booster booster = XGBoost.train(paramMap, trainMat, round, watches, null, null);
//predict raw output
float[][] predicts = booster.predict(testMat, true);
//eval
IEvaluation eval = new EvalError();
//error must be less than 0.1
TestCase.assertTrue(eval.eval(predicts, testMat) < 0.1f);
//test dump model
}
/**
* test cross valiation
*
* @throws XGBoostError
*/
@Test
public void testCV() throws XGBoostError {
//load train mat
org.dmlc.xgboost4j.DMatrix trainMat = new org.dmlc.xgboost4j.DMatrix("../../demo/data/agaricus.txt.train");
//set params
Map<String, Object> param = new HashMap<String, Object>() {
{
put("eta", 1.0);
put("max_depth", 3);
put("silent", 1);
put("nthread", 6);
put("objective", "binary:logistic");
put("gamma", 1.0);
put("eval_metric", "error");
}
};
//do 5-fold cross validation
int round = 2;
int nfold = 5;
//set additional eval_metrics
String[] metrics = null;
String[] evalHist = XGBoost.crossValiation(param, trainMat, round, nfold, metrics,
null, null);
}
}

View File

@@ -0,0 +1,104 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j;
import junit.framework.TestCase;
import org.dmlc.xgboost4j.*;
import org.junit.Test;
import java.util.Arrays;
import java.util.Random;
/**
* test cases for DMatrix
*
* @author hzx
*/
public class DMatrixTest {
@Test
public void testCreateFromFile() throws XGBoostError {
//create DMatrix from file
org.dmlc.xgboost4j.DMatrix dmat = new org.dmlc.xgboost4j.DMatrix("../../demo/data/agaricus.txt.test");
//get label
float[] labels = dmat.getLabel();
//check length
TestCase.assertTrue(dmat.rowNum() == labels.length);
//set weights
float[] weights = Arrays.copyOf(labels, labels.length);
dmat.setWeight(weights);
float[] dweights = dmat.getWeight();
TestCase.assertTrue(Arrays.equals(weights, dweights));
}
@Test
public void testCreateFromCSR() throws XGBoostError {
//create Matrix from csr format sparse Matrix and labels
/**
* sparse matrix
* 1 0 2 3 0
* 4 0 2 3 5
* 3 1 2 5 0
*/
float[] data = new float[]{1, 2, 3, 4, 2, 3, 5, 3, 1, 2, 5};
int[] colIndex = new int[]{0, 2, 3, 0, 2, 3, 4, 0, 1, 2, 3};
long[] rowHeaders = new long[]{0, 3, 7, 11};
org.dmlc.xgboost4j.DMatrix dmat1 = new org.dmlc.xgboost4j.DMatrix(rowHeaders, colIndex, data, org.dmlc.xgboost4j.DMatrix.SparseType.CSR);
//check row num
System.out.println(dmat1.rowNum());
TestCase.assertTrue(dmat1.rowNum() == 3);
//test set label
float[] label1 = new float[]{1, 0, 1};
dmat1.setLabel(label1);
float[] label2 = dmat1.getLabel();
TestCase.assertTrue(Arrays.equals(label1, label2));
}
@Test
public void testCreateFromDenseMatrix() throws XGBoostError {
//create DMatrix from 10*5 dense matrix
int nrow = 10;
int ncol = 5;
float[] data0 = new float[nrow * ncol];
//put random nums
Random random = new Random();
for (int i = 0; i < nrow * ncol; i++) {
data0[i] = random.nextFloat();
}
//create label
float[] label0 = new float[nrow];
for (int i = 0; i < nrow; i++) {
label0[i] = random.nextFloat();
}
org.dmlc.xgboost4j.DMatrix dmat0 = new org.dmlc.xgboost4j.DMatrix(data0, nrow, ncol);
dmat0.setLabel(label0);
//check
TestCase.assertTrue(dmat0.rowNum() == 10);
TestCase.assertTrue(dmat0.getLabel().length == 10);
//set weights for each instance
float[] weights = new float[nrow];
for (int i = 0; i < nrow; i++) {
weights[i] = random.nextFloat();
}
dmat0.setWeight(weights);
TestCase.assertTrue(Arrays.equals(weights, dmat0.getWeight()));
}
}