refactor jni code and rename libxgboostjavawrapper.so to libxgboost4j.so
This commit is contained in:
@@ -233,7 +233,7 @@ public final class Booster {
|
||||
* @param predLeaf
|
||||
* @return predict results
|
||||
*/
|
||||
private synchronized float[][] pred(DMatrix data, boolean outPutMargin, long treeLimit, boolean predLeaf) throws XGBoostError {
|
||||
private synchronized float[][] pred(DMatrix data, boolean outPutMargin, int treeLimit, boolean predLeaf) throws XGBoostError {
|
||||
int optionMask = 0;
|
||||
if(outPutMargin) {
|
||||
optionMask = 1;
|
||||
@@ -284,7 +284,7 @@ public final class Booster {
|
||||
* @return predict result
|
||||
* @throws org.dmlc.xgboost4j.util.XGBoostError
|
||||
*/
|
||||
public float[][] predict(DMatrix data, boolean outPutMargin, long treeLimit) throws XGBoostError {
|
||||
public float[][] predict(DMatrix data, boolean outPutMargin, int treeLimit) throws XGBoostError {
|
||||
return pred(data, outPutMargin, treeLimit, false);
|
||||
}
|
||||
|
||||
@@ -299,7 +299,7 @@ public final class Booster {
|
||||
* @return predict result
|
||||
* @throws org.dmlc.xgboost4j.util.XGBoostError
|
||||
*/
|
||||
public float[][] predict(DMatrix data , long treeLimit, boolean predLeaf) throws XGBoostError {
|
||||
public float[][] predict(DMatrix data , int treeLimit, boolean predLeaf) throws XGBoostError {
|
||||
return pred(data, false, treeLimit, predLeaf);
|
||||
}
|
||||
|
||||
|
||||
@@ -31,7 +31,7 @@ public class Initializer {
|
||||
static boolean initialized = false;
|
||||
public static final String nativePath = "./lib";
|
||||
public static final String nativeResourcePath = "/lib/";
|
||||
public static final String[] libNames = new String[] {"xgboostjavawrapper"};
|
||||
public static final String[] libNames = new String[] {"xgboost4j"};
|
||||
|
||||
public static synchronized void InitXgboost() throws IOException {
|
||||
if(initialized == false) {
|
||||
|
||||
@@ -41,10 +41,10 @@ public class XgboostJNI {
|
||||
public final static native int XGBoosterUpdateOneIter(long handle, int iter, long dtrain);
|
||||
public final static native int XGBoosterBoostOneIter(long handle, long dtrain, float[] grad, float[] hess);
|
||||
public final static native int XGBoosterEvalOneIter(long handle, int iter, long[] dmats, String[] evnames, String[] eval_info);
|
||||
public final static native int XGBoosterPredict(long handle, long dmat, int option_mask, long ntree_limit, float[][] predicts);
|
||||
public final static native int XGBoosterPredict(long handle, long dmat, int option_mask, int ntree_limit, float[][] predicts);
|
||||
public final static native int XGBoosterLoadModel(long handle, String fname);
|
||||
public final static native int XGBoosterSaveModel(long handle, String fname);
|
||||
public final static native int XGBoosterLoadModelFromBuffer(long handle, long buf, long len);
|
||||
public final static native int XGBoosterGetModelRaw(long handle, String[] out_string);
|
||||
public final static native int XGBoosterDumpModel(long handle, String fmap, int with_stats, String[][] out_strings);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -104,5 +104,39 @@ public class BoosterTest {
|
||||
IEvaluation eval = new EvalError();
|
||||
//error must be less than 0.1
|
||||
TestCase.assertTrue(eval.eval(predicts, testMat)<0.1f);
|
||||
|
||||
//test dump model
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* test cross valiation
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
@Test
|
||||
public void testCV() throws XGBoostError {
|
||||
//load train mat
|
||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||
|
||||
//set params
|
||||
Map<String, Object> param= new HashMap<String, Object>() {
|
||||
{
|
||||
put("eta", 1.0);
|
||||
put("max_depth", 3);
|
||||
put("silent", 1);
|
||||
put("nthread", 6);
|
||||
put("objective", "binary:logistic");
|
||||
put("gamma", 1.0);
|
||||
put("eval_metric", "error");
|
||||
}
|
||||
};
|
||||
|
||||
//do 5-fold cross validation
|
||||
int round = 2;
|
||||
int nfold = 5;
|
||||
//set additional eval_metrics
|
||||
String[] metrics = null;
|
||||
|
||||
String[] evalHist = Trainer.crossValiation(param.entrySet(), trainMat, round, nfold, metrics, null, null);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user