update java wrapper for new fault handle API
This commit is contained in:
@@ -31,6 +31,7 @@ import org.dmlc.xgboost4j.DMatrix;
|
||||
import org.dmlc.xgboost4j.demo.util.DataLoader;
|
||||
import org.dmlc.xgboost4j.demo.util.Params;
|
||||
import org.dmlc.xgboost4j.util.Trainer;
|
||||
import org.dmlc.xgboost4j.util.XgboostError;
|
||||
|
||||
/**
|
||||
* a simple example of java wrapper for xgboost
|
||||
@@ -52,7 +53,7 @@ public class BasicWalkThrough {
|
||||
}
|
||||
|
||||
|
||||
public static void main(String[] args) throws UnsupportedEncodingException, IOException {
|
||||
public static void main(String[] args) throws UnsupportedEncodingException, IOException, XgboostError {
|
||||
// load file from text file, also binary buffer generated by xgboost4j
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||
|
||||
@@ -23,13 +23,14 @@ import org.dmlc.xgboost4j.Booster;
|
||||
import org.dmlc.xgboost4j.DMatrix;
|
||||
import org.dmlc.xgboost4j.demo.util.Params;
|
||||
import org.dmlc.xgboost4j.util.Trainer;
|
||||
import org.dmlc.xgboost4j.util.XgboostError;
|
||||
|
||||
/**
|
||||
* example for start from a initial base prediction
|
||||
* @author hzx
|
||||
*/
|
||||
public class BoostFromPrediction {
|
||||
public static void main(String[] args) {
|
||||
public static void main(String[] args) throws XgboostError {
|
||||
System.out.println("start running example to start from a initial prediction");
|
||||
|
||||
// load file from text file, also binary buffer generated by xgboost4j
|
||||
|
||||
@@ -19,13 +19,14 @@ import java.io.IOException;
|
||||
import org.dmlc.xgboost4j.DMatrix;
|
||||
import org.dmlc.xgboost4j.util.Trainer;
|
||||
import org.dmlc.xgboost4j.demo.util.Params;
|
||||
import org.dmlc.xgboost4j.util.XgboostError;
|
||||
|
||||
/**
|
||||
* an example of cross validation
|
||||
* @author hzx
|
||||
*/
|
||||
public class CrossValidation {
|
||||
public static void main(String[] args) throws IOException {
|
||||
public static void main(String[] args) throws IOException, XgboostError {
|
||||
//load train mat
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
|
||||
|
||||
@@ -19,12 +19,15 @@ import java.util.AbstractMap;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import org.apache.commons.logging.Log;
|
||||
import org.apache.commons.logging.LogFactory;
|
||||
import org.dmlc.xgboost4j.Booster;
|
||||
import org.dmlc.xgboost4j.IEvaluation;
|
||||
import org.dmlc.xgboost4j.DMatrix;
|
||||
import org.dmlc.xgboost4j.IObjective;
|
||||
import org.dmlc.xgboost4j.demo.util.Params;
|
||||
import org.dmlc.xgboost4j.util.Trainer;
|
||||
import org.dmlc.xgboost4j.util.XgboostError;
|
||||
|
||||
/**
|
||||
* an example user define objective and eval
|
||||
@@ -40,6 +43,8 @@ public class CustomObjective {
|
||||
* loglikelihoode loss obj function
|
||||
*/
|
||||
public static class LogRegObj implements IObjective {
|
||||
private static final Log logger = LogFactory.getLog(LogRegObj.class);
|
||||
|
||||
/**
|
||||
* simple sigmoid func
|
||||
* @param input
|
||||
@@ -66,7 +71,13 @@ public class CustomObjective {
|
||||
public List<float[]> getGradient(float[][] predicts, DMatrix dtrain) {
|
||||
int nrow = predicts.length;
|
||||
List<float[]> gradients = new ArrayList<>();
|
||||
float[] labels = dtrain.getLabel();
|
||||
float[] labels;
|
||||
try {
|
||||
labels = dtrain.getLabel();
|
||||
} catch (XgboostError ex) {
|
||||
logger.error(ex);
|
||||
return null;
|
||||
}
|
||||
float[] grad = new float[nrow];
|
||||
float[] hess = new float[nrow];
|
||||
|
||||
@@ -93,6 +104,8 @@ public class CustomObjective {
|
||||
* Take this in mind when you use the customization, and maybe you need write customized evaluation function
|
||||
*/
|
||||
public static class EvalError implements IEvaluation {
|
||||
private static final Log logger = LogFactory.getLog(EvalError.class);
|
||||
|
||||
String evalMetric = "custom_error";
|
||||
|
||||
public EvalError() {
|
||||
@@ -106,7 +119,13 @@ public class CustomObjective {
|
||||
@Override
|
||||
public float eval(float[][] predicts, DMatrix dmat) {
|
||||
float error = 0f;
|
||||
float[] labels = dmat.getLabel();
|
||||
float[] labels;
|
||||
try {
|
||||
labels = dmat.getLabel();
|
||||
} catch (XgboostError ex) {
|
||||
logger.error(ex);
|
||||
return -1f;
|
||||
}
|
||||
int nrow = predicts.length;
|
||||
for(int i=0; i<nrow; i++) {
|
||||
if(labels[i]==0f && predicts[i][0]>0) {
|
||||
@@ -121,7 +140,7 @@ public class CustomObjective {
|
||||
}
|
||||
}
|
||||
|
||||
public static void main(String[] args) {
|
||||
public static void main(String[] args) throws XgboostError {
|
||||
//load train mat (svmlight format)
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
//load valid mat (svmlight format)
|
||||
|
||||
@@ -23,13 +23,14 @@ import org.dmlc.xgboost4j.Booster;
|
||||
import org.dmlc.xgboost4j.DMatrix;
|
||||
import org.dmlc.xgboost4j.demo.util.Params;
|
||||
import org.dmlc.xgboost4j.util.Trainer;
|
||||
import org.dmlc.xgboost4j.util.XgboostError;
|
||||
|
||||
/**
|
||||
* simple example for using external memory version
|
||||
* @author hzx
|
||||
*/
|
||||
public class ExternalMemory {
|
||||
public static void main(String[] args) {
|
||||
public static void main(String[] args) throws XgboostError {
|
||||
//this is the only difference, add a # followed by a cache prefix name
|
||||
//several cache file with the prefix will be generated
|
||||
//currently only support convert from libsvm file
|
||||
|
||||
@@ -24,6 +24,7 @@ import org.dmlc.xgboost4j.DMatrix;
|
||||
import org.dmlc.xgboost4j.demo.util.CustomEval;
|
||||
import org.dmlc.xgboost4j.demo.util.Params;
|
||||
import org.dmlc.xgboost4j.util.Trainer;
|
||||
import org.dmlc.xgboost4j.util.XgboostError;
|
||||
|
||||
/**
|
||||
* this is an example of fit generalized linear model in xgboost
|
||||
@@ -31,7 +32,7 @@ import org.dmlc.xgboost4j.util.Trainer;
|
||||
* @author hzx
|
||||
*/
|
||||
public class GeneralizedLinearModel {
|
||||
public static void main(String[] args) {
|
||||
public static void main(String[] args) throws XgboostError {
|
||||
// load file from text file, also binary buffer generated by xgboost4j
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||
|
||||
@@ -25,13 +25,14 @@ import org.dmlc.xgboost4j.util.Trainer;
|
||||
|
||||
import org.dmlc.xgboost4j.demo.util.CustomEval;
|
||||
import org.dmlc.xgboost4j.demo.util.Params;
|
||||
import org.dmlc.xgboost4j.util.XgboostError;
|
||||
|
||||
/**
|
||||
* predict first ntree
|
||||
* @author hzx
|
||||
*/
|
||||
public class PredictFirstNtree {
|
||||
public static void main(String[] args) {
|
||||
public static void main(String[] args) throws XgboostError {
|
||||
// load file from text file, also binary buffer generated by xgboost4j
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||
|
||||
@@ -24,13 +24,14 @@ import org.dmlc.xgboost4j.Booster;
|
||||
import org.dmlc.xgboost4j.DMatrix;
|
||||
import org.dmlc.xgboost4j.util.Trainer;
|
||||
import org.dmlc.xgboost4j.demo.util.Params;
|
||||
import org.dmlc.xgboost4j.util.XgboostError;
|
||||
|
||||
/**
|
||||
* predict leaf indices
|
||||
* @author hzx
|
||||
*/
|
||||
public class PredictLeafIndices {
|
||||
public static void main(String[] args) {
|
||||
public static void main(String[] args) throws XgboostError {
|
||||
// load file from text file, also binary buffer generated by xgboost4j
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||
|
||||
@@ -15,14 +15,18 @@
|
||||
*/
|
||||
package org.dmlc.xgboost4j.demo.util;
|
||||
|
||||
import org.apache.commons.logging.Log;
|
||||
import org.apache.commons.logging.LogFactory;
|
||||
import org.dmlc.xgboost4j.DMatrix;
|
||||
import org.dmlc.xgboost4j.IEvaluation;
|
||||
import org.dmlc.xgboost4j.util.XgboostError;
|
||||
|
||||
/**
|
||||
* a util evaluation class for examples
|
||||
* @author hzx
|
||||
*/
|
||||
public class CustomEval implements IEvaluation {
|
||||
private static final Log logger = LogFactory.getLog(CustomEval.class);
|
||||
|
||||
String evalMetric = "custom_error";
|
||||
|
||||
@@ -34,7 +38,13 @@ public class CustomEval implements IEvaluation {
|
||||
@Override
|
||||
public float eval(float[][] predicts, DMatrix dmat) {
|
||||
float error = 0f;
|
||||
float[] labels = dmat.getLabel();
|
||||
float[] labels;
|
||||
try {
|
||||
labels = dmat.getLabel();
|
||||
} catch (XgboostError ex) {
|
||||
logger.error(ex);
|
||||
return -1f;
|
||||
}
|
||||
int nrow = predicts.length;
|
||||
for(int i=0; i<nrow; i++) {
|
||||
if(labels[i]==0f && predicts[i][0]>0.5) {
|
||||
|
||||
@@ -77,10 +77,8 @@ public class DataLoader {
|
||||
reader.close();
|
||||
in.close();
|
||||
|
||||
Float[] flabels = (Float[]) tlabels.toArray();
|
||||
denseData.labels = ArrayUtils.toPrimitive(flabels);
|
||||
Float[] fdata = (Float[]) tdata.toArray();
|
||||
denseData.data = ArrayUtils.toPrimitive(fdata);
|
||||
denseData.labels = ArrayUtils.toPrimitive(tlabels.toArray(new Float[tlabels.size()]));
|
||||
denseData.data = ArrayUtils.toPrimitive(tdata.toArray(new Float[tdata.size()]));
|
||||
|
||||
return denseData;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user