refactor jni code and rename libxgboostjavawrapper.so to libxgboost4j.so

This commit is contained in:
yanqingmen
2015-12-19 22:26:40 +08:00
parent f378fac6a1
commit 1456585249
9 changed files with 202 additions and 341 deletions

View File

@@ -1,23 +0,0 @@
# xgboost4j
this is a java wrapper for xgboost (https://github.com/dmlc/xgboost)
the structure of this wrapper is almost the same as the official python wrapper.
core of this wrapper is two classes:
* DMatrix for handling data
* Booster: for train and predict
## usage:
simple examples could be found in test package:
* Simple Train Example: org.dmlc.xgboost4j.TrainExample.java
* Simple Predict Example: org.dmlc.xgboost4j.PredictExample.java
* Cross Validation Example: org.dmlc.xgboost4j.example.CVExample.java
## native library:
only 64-bit linux/windows is supported now, if you want to build native wrapper library yourself, please refer to
https://github.com/yanqingmen/xgboost-java, and put your native library to the "./src/main/resources/lib" folder and replace the originals. (either "libxgboostjavawrapper.so" for linux or "xgboostjavawrapper.dll" for windows)

View File

@@ -233,7 +233,7 @@ public final class Booster {
* @param predLeaf
* @return predict results
*/
private synchronized float[][] pred(DMatrix data, boolean outPutMargin, long treeLimit, boolean predLeaf) throws XGBoostError {
private synchronized float[][] pred(DMatrix data, boolean outPutMargin, int treeLimit, boolean predLeaf) throws XGBoostError {
int optionMask = 0;
if(outPutMargin) {
optionMask = 1;
@@ -284,7 +284,7 @@ public final class Booster {
* @return predict result
* @throws org.dmlc.xgboost4j.util.XGBoostError
*/
public float[][] predict(DMatrix data, boolean outPutMargin, long treeLimit) throws XGBoostError {
public float[][] predict(DMatrix data, boolean outPutMargin, int treeLimit) throws XGBoostError {
return pred(data, outPutMargin, treeLimit, false);
}
@@ -299,7 +299,7 @@ public final class Booster {
* @return predict result
* @throws org.dmlc.xgboost4j.util.XGBoostError
*/
public float[][] predict(DMatrix data , long treeLimit, boolean predLeaf) throws XGBoostError {
public float[][] predict(DMatrix data , int treeLimit, boolean predLeaf) throws XGBoostError {
return pred(data, false, treeLimit, predLeaf);
}

View File

@@ -31,7 +31,7 @@ public class Initializer {
static boolean initialized = false;
public static final String nativePath = "./lib";
public static final String nativeResourcePath = "/lib/";
public static final String[] libNames = new String[] {"xgboostjavawrapper"};
public static final String[] libNames = new String[] {"xgboost4j"};
public static synchronized void InitXgboost() throws IOException {
if(initialized == false) {

View File

@@ -41,10 +41,10 @@ public class XgboostJNI {
public final static native int XGBoosterUpdateOneIter(long handle, int iter, long dtrain);
public final static native int XGBoosterBoostOneIter(long handle, long dtrain, float[] grad, float[] hess);
public final static native int XGBoosterEvalOneIter(long handle, int iter, long[] dmats, String[] evnames, String[] eval_info);
public final static native int XGBoosterPredict(long handle, long dmat, int option_mask, long ntree_limit, float[][] predicts);
public final static native int XGBoosterPredict(long handle, long dmat, int option_mask, int ntree_limit, float[][] predicts);
public final static native int XGBoosterLoadModel(long handle, String fname);
public final static native int XGBoosterSaveModel(long handle, String fname);
public final static native int XGBoosterLoadModelFromBuffer(long handle, long buf, long len);
public final static native int XGBoosterGetModelRaw(long handle, String[] out_string);
public final static native int XGBoosterDumpModel(long handle, String fmap, int with_stats, String[][] out_strings);
}
}

View File

@@ -104,5 +104,39 @@ public class BoosterTest {
IEvaluation eval = new EvalError();
//error must be less than 0.1
TestCase.assertTrue(eval.eval(predicts, testMat)<0.1f);
//test dump model
}
/**
* test cross valiation
* @throws XGBoostError
*/
@Test
public void testCV() throws XGBoostError {
//load train mat
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
//set params
Map<String, Object> param= new HashMap<String, Object>() {
{
put("eta", 1.0);
put("max_depth", 3);
put("silent", 1);
put("nthread", 6);
put("objective", "binary:logistic");
put("gamma", 1.0);
put("eval_metric", "error");
}
};
//do 5-fold cross validation
int round = 2;
int nfold = 5;
//set additional eval_metrics
String[] metrics = null;
String[] evalHist = Trainer.crossValiation(param.entrySet(), trainMat, round, nfold, metrics, null, null);
}
}