Update Java wrapper for new fault handle API

This commit is contained in:
yanqingmen
2015-07-06 02:32:58 -07:00
parent 7755c00721
commit f73bcd427d
19 changed files with 558 additions and 329 deletions

View File

@@ -31,6 +31,7 @@ import org.dmlc.xgboost4j.DMatrix;
import org.dmlc.xgboost4j.demo.util.DataLoader;
import org.dmlc.xgboost4j.demo.util.Params;
import org.dmlc.xgboost4j.util.Trainer;
import org.dmlc.xgboost4j.util.XgboostError;
/**
* a simple example of java wrapper for xgboost
@@ -52,7 +53,7 @@ public class BasicWalkThrough {
}
public static void main(String[] args) throws UnsupportedEncodingException, IOException {
public static void main(String[] args) throws UnsupportedEncodingException, IOException, XgboostError {
// load file from text file, also binary buffer generated by xgboost4j
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");

View File

@@ -23,13 +23,14 @@ import org.dmlc.xgboost4j.Booster;
import org.dmlc.xgboost4j.DMatrix;
import org.dmlc.xgboost4j.demo.util.Params;
import org.dmlc.xgboost4j.util.Trainer;
import org.dmlc.xgboost4j.util.XgboostError;
/**
* example of starting from an initial base prediction
* @author hzx
*/
public class BoostFromPrediction {
public static void main(String[] args) {
public static void main(String[] args) throws XgboostError {
System.out.println("start running example to start from a initial prediction");
// load file from text file, also binary buffer generated by xgboost4j

View File

@@ -19,13 +19,14 @@ import java.io.IOException;
import org.dmlc.xgboost4j.DMatrix;
import org.dmlc.xgboost4j.util.Trainer;
import org.dmlc.xgboost4j.demo.util.Params;
import org.dmlc.xgboost4j.util.XgboostError;
/**
* an example of cross validation
* @author hzx
*/
public class CrossValidation {
public static void main(String[] args) throws IOException {
public static void main(String[] args) throws IOException, XgboostError {
//load train mat
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");

View File

@@ -19,12 +19,15 @@ import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.dmlc.xgboost4j.Booster;
import org.dmlc.xgboost4j.IEvaluation;
import org.dmlc.xgboost4j.DMatrix;
import org.dmlc.xgboost4j.IObjective;
import org.dmlc.xgboost4j.demo.util.Params;
import org.dmlc.xgboost4j.util.Trainer;
import org.dmlc.xgboost4j.util.XgboostError;
/**
* an example user define objective and eval
@@ -40,6 +43,8 @@ public class CustomObjective {
* log-likelihood loss objective function
*/
public static class LogRegObj implements IObjective {
private static final Log logger = LogFactory.getLog(LogRegObj.class);
/**
* simple sigmoid func
* @param input
@@ -66,7 +71,13 @@ public class CustomObjective {
public List<float[]> getGradient(float[][] predicts, DMatrix dtrain) {
int nrow = predicts.length;
List<float[]> gradients = new ArrayList<>();
float[] labels = dtrain.getLabel();
float[] labels;
try {
labels = dtrain.getLabel();
} catch (XgboostError ex) {
logger.error(ex);
return null;
}
float[] grad = new float[nrow];
float[] hess = new float[nrow];
@@ -93,6 +104,8 @@ public class CustomObjective {
* Keep this in mind when you use the customization; you may need to write a customized evaluation function
*/
public static class EvalError implements IEvaluation {
private static final Log logger = LogFactory.getLog(EvalError.class);
String evalMetric = "custom_error";
public EvalError() {
@@ -106,7 +119,13 @@ public class CustomObjective {
@Override
public float eval(float[][] predicts, DMatrix dmat) {
float error = 0f;
float[] labels = dmat.getLabel();
float[] labels;
try {
labels = dmat.getLabel();
} catch (XgboostError ex) {
logger.error(ex);
return -1f;
}
int nrow = predicts.length;
for(int i=0; i<nrow; i++) {
if(labels[i]==0f && predicts[i][0]>0) {
@@ -121,7 +140,7 @@ public class CustomObjective {
}
}
public static void main(String[] args) {
public static void main(String[] args) throws XgboostError {
//load train mat (svmlight format)
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
//load valid mat (svmlight format)

View File

@@ -23,13 +23,14 @@ import org.dmlc.xgboost4j.Booster;
import org.dmlc.xgboost4j.DMatrix;
import org.dmlc.xgboost4j.demo.util.Params;
import org.dmlc.xgboost4j.util.Trainer;
import org.dmlc.xgboost4j.util.XgboostError;
/**
* simple example of using the external-memory version
* @author hzx
*/
public class ExternalMemory {
public static void main(String[] args) {
public static void main(String[] args) throws XgboostError {
//this is the only difference, add a # followed by a cache prefix name
//several cache file with the prefix will be generated
//currently only support convert from libsvm file

View File

@@ -24,6 +24,7 @@ import org.dmlc.xgboost4j.DMatrix;
import org.dmlc.xgboost4j.demo.util.CustomEval;
import org.dmlc.xgboost4j.demo.util.Params;
import org.dmlc.xgboost4j.util.Trainer;
import org.dmlc.xgboost4j.util.XgboostError;
/**
* this is an example of fit generalized linear model in xgboost
@@ -31,7 +32,7 @@ import org.dmlc.xgboost4j.util.Trainer;
* @author hzx
*/
public class GeneralizedLinearModel {
public static void main(String[] args) {
public static void main(String[] args) throws XgboostError {
// load file from text file, also binary buffer generated by xgboost4j
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");

View File

@@ -25,13 +25,14 @@ import org.dmlc.xgboost4j.util.Trainer;
import org.dmlc.xgboost4j.demo.util.CustomEval;
import org.dmlc.xgboost4j.demo.util.Params;
import org.dmlc.xgboost4j.util.XgboostError;
/**
* predict first ntree
* @author hzx
*/
public class PredictFirstNtree {
public static void main(String[] args) {
public static void main(String[] args) throws XgboostError {
// load file from text file, also binary buffer generated by xgboost4j
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");

View File

@@ -24,13 +24,14 @@ import org.dmlc.xgboost4j.Booster;
import org.dmlc.xgboost4j.DMatrix;
import org.dmlc.xgboost4j.util.Trainer;
import org.dmlc.xgboost4j.demo.util.Params;
import org.dmlc.xgboost4j.util.XgboostError;
/**
* predict leaf indices
* @author hzx
*/
public class PredictLeafIndices {
public static void main(String[] args) {
public static void main(String[] args) throws XgboostError {
// load file from text file, also binary buffer generated by xgboost4j
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");

View File

@@ -15,14 +15,18 @@
*/
package org.dmlc.xgboost4j.demo.util;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.dmlc.xgboost4j.DMatrix;
import org.dmlc.xgboost4j.IEvaluation;
import org.dmlc.xgboost4j.util.XgboostError;
/**
* a util evaluation class for examples
* @author hzx
*/
public class CustomEval implements IEvaluation {
private static final Log logger = LogFactory.getLog(CustomEval.class);
String evalMetric = "custom_error";
@@ -34,7 +38,13 @@ public class CustomEval implements IEvaluation {
@Override
public float eval(float[][] predicts, DMatrix dmat) {
float error = 0f;
float[] labels = dmat.getLabel();
float[] labels;
try {
labels = dmat.getLabel();
} catch (XgboostError ex) {
logger.error(ex);
return -1f;
}
int nrow = predicts.length;
for(int i=0; i<nrow; i++) {
if(labels[i]==0f && predicts[i][0]>0.5) {

View File

@@ -77,10 +77,8 @@ public class DataLoader {
reader.close();
in.close();
Float[] flabels = (Float[]) tlabels.toArray();
denseData.labels = ArrayUtils.toPrimitive(flabels);
Float[] fdata = (Float[]) tdata.toArray();
denseData.data = ArrayUtils.toPrimitive(fdata);
denseData.labels = ArrayUtils.toPrimitive(tlabels.toArray(new Float[tlabels.size()]));
denseData.data = ArrayUtils.toPrimitive(tdata.toArray(new Float[tdata.size()]));
return denseData;
}