From f73bcd427dcdd3602458f2b1cb873085f1fd768b Mon Sep 17 00:00:00 2001 From: yanqingmen Date: Mon, 6 Jul 2015 02:32:58 -0700 Subject: [PATCH] update java wrapper for new fault handle API --- .../dmlc/xgboost4j/demo/BasicWalkThrough.java | 3 +- .../xgboost4j/demo/BoostFromPrediction.java | 3 +- .../dmlc/xgboost4j/demo/CrossValidation.java | 3 +- .../dmlc/xgboost4j/demo/CustomObjective.java | 25 +- .../dmlc/xgboost4j/demo/ExternalMemory.java | 3 +- .../demo/GeneralizedLinearModel.java | 3 +- .../xgboost4j/demo/PredictFirstNtree.java | 3 +- .../xgboost4j/demo/PredictLeafIndices.java | 3 +- .../dmlc/xgboost4j/demo/util/CustomEval.java | 12 +- .../dmlc/xgboost4j/demo/util/DataLoader.java | 6 +- .../main/java/org/dmlc/xgboost4j/Booster.java | 112 +++--- .../main/java/org/dmlc/xgboost4j/DMatrix.java | 86 +++-- .../java/org/dmlc/xgboost4j/util/CVPack.java | 15 +- .../org/dmlc/xgboost4j/util/ErrorHandle.java | 50 +++ .../java/org/dmlc/xgboost4j/util/Trainer.java | 6 +- .../org/dmlc/xgboost4j/util/XgboostError.java | 26 ++ .../dmlc/xgboost4j/wrapper/XgboostJNI.java | 52 +-- java/xgboost4j_wrapper.cpp | 340 ++++++++++-------- java/xgboost4j_wrapper.h | 136 +++---- 19 files changed, 558 insertions(+), 329 deletions(-) create mode 100644 java/xgboost4j/src/main/java/org/dmlc/xgboost4j/util/ErrorHandle.java create mode 100644 java/xgboost4j/src/main/java/org/dmlc/xgboost4j/util/XgboostError.java diff --git a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/BasicWalkThrough.java b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/BasicWalkThrough.java index a0c7a3ae1..86ba49c48 100644 --- a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/BasicWalkThrough.java +++ b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/BasicWalkThrough.java @@ -31,6 +31,7 @@ import org.dmlc.xgboost4j.DMatrix; import org.dmlc.xgboost4j.demo.util.DataLoader; import org.dmlc.xgboost4j.demo.util.Params; import org.dmlc.xgboost4j.util.Trainer; +import org.dmlc.xgboost4j.util.XgboostError; /** * a simple example of java wrapper for xgboost @@ -52,7 +53,7 @@ public class BasicWalkThrough { } - public static void main(String[] args) throws UnsupportedEncodingException, IOException { + public static void main(String[] args) throws UnsupportedEncodingException, IOException, XgboostError { // load file from text file, also binary buffer generated by xgboost4j DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test"); diff --git a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/BoostFromPrediction.java b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/BoostFromPrediction.java index 733c49503..1113eef68 100644 --- a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/BoostFromPrediction.java +++ b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/BoostFromPrediction.java @@ -23,13 +23,14 @@ import org.dmlc.xgboost4j.Booster; import org.dmlc.xgboost4j.DMatrix; import org.dmlc.xgboost4j.demo.util.Params; import org.dmlc.xgboost4j.util.Trainer; +import org.dmlc.xgboost4j.util.XgboostError; /** * example for start from a initial base prediction * @author hzx */ public class BoostFromPrediction { - public static void main(String[] args) { + public static void main(String[] args) throws XgboostError { System.out.println("start running example to start from a initial prediction"); // load file from text file, also binary buffer generated by xgboost4j diff --git a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/CrossValidation.java b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/CrossValidation.java index 0c470bf17..ec5716700 100644 --- a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/CrossValidation.java +++ b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/CrossValidation.java @@ -19,13 +19,14 @@ import java.io.IOException; import org.dmlc.xgboost4j.DMatrix; import org.dmlc.xgboost4j.util.Trainer; import org.dmlc.xgboost4j.demo.util.Params; +import org.dmlc.xgboost4j.util.XgboostError; /** * an example of cross validation * @author hzx */ public class CrossValidation { - public static void main(String[] args) throws IOException { + public static void main(String[] args) throws IOException, XgboostError { //load train mat DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); diff --git a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/CustomObjective.java b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/CustomObjective.java index 03c9c4b52..4aaa053e0 100644 --- a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/CustomObjective.java +++ b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/CustomObjective.java @@ -19,12 +19,15 @@ import java.util.AbstractMap; import java.util.ArrayList; import java.util.List; import java.util.Map; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; import org.dmlc.xgboost4j.Booster; import org.dmlc.xgboost4j.IEvaluation; import org.dmlc.xgboost4j.DMatrix; import org.dmlc.xgboost4j.IObjective; import org.dmlc.xgboost4j.demo.util.Params; import org.dmlc.xgboost4j.util.Trainer; +import org.dmlc.xgboost4j.util.XgboostError; /** * an example user define objective and eval @@ -40,6 +43,8 @@ public class CustomObjective { * loglikelihoode loss obj function */ public static class LogRegObj implements IObjective { + private static final Log logger = LogFactory.getLog(LogRegObj.class); + /** * simple sigmoid func * @param input @@ -66,7 +71,13 @@ public class CustomObjective { public List getGradient(float[][] predicts, DMatrix dtrain) { int nrow = predicts.length; List gradients = new ArrayList<>(); - float[] labels = dtrain.getLabel(); + float[] labels; + try { + labels = dtrain.getLabel(); + } catch (XgboostError ex) { + logger.error(ex); + return null; + } float[] grad = new float[nrow]; float[] hess = new float[nrow]; @@ -93,6 +104,8 @@ public class CustomObjective { * Take this in mind when you use the customization, and maybe you need write customized evaluation function */ public static class EvalError implements IEvaluation { + private static final Log logger = LogFactory.getLog(EvalError.class); + String evalMetric = "custom_error"; public EvalError() { @@ -106,7 +119,13 @@ public class CustomObjective { @Override public float eval(float[][] predicts, DMatrix dmat) { float error = 0f; - float[] labels = dmat.getLabel(); + float[] labels; + try { + labels = dmat.getLabel(); + } catch (XgboostError ex) { + logger.error(ex); + return -1f; + } int nrow = predicts.length; for(int i=0; i0) { @@ -121,7 +140,7 @@ public class CustomObjective { } } - public static void main(String[] args) { + public static void main(String[] args) throws XgboostError { //load train mat (svmlight format) DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); //load valid mat (svmlight format) diff --git a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/ExternalMemory.java b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/ExternalMemory.java index 6ac687289..e74e3e858 100644 --- a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/ExternalMemory.java +++ b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/ExternalMemory.java @@ -23,13 +23,14 @@ import org.dmlc.xgboost4j.Booster; import org.dmlc.xgboost4j.DMatrix; import org.dmlc.xgboost4j.demo.util.Params; import org.dmlc.xgboost4j.util.Trainer; +import org.dmlc.xgboost4j.util.XgboostError; /** * simple example for using external memory version * @author hzx */ public class ExternalMemory { - public static void main(String[] args) { + public static void main(String[] args) throws XgboostError { //this is the only difference, add a # followed by a cache prefix name //several cache file with the prefix will be generated //currently only support convert from libsvm file diff --git a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/GeneralizedLinearModel.java b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/GeneralizedLinearModel.java index 2a20edbff..db3cd0e59 100644 --- a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/GeneralizedLinearModel.java +++ b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/GeneralizedLinearModel.java @@ -24,6 +24,7 @@ import org.dmlc.xgboost4j.DMatrix; import org.dmlc.xgboost4j.demo.util.CustomEval; import org.dmlc.xgboost4j.demo.util.Params; import org.dmlc.xgboost4j.util.Trainer; +import org.dmlc.xgboost4j.util.XgboostError; /** * this is an example of fit generalized linear model in xgboost @@ -31,7 +32,7 @@ import org.dmlc.xgboost4j.util.Trainer; * @author hzx */ public class GeneralizedLinearModel { - public static void main(String[] args) { + public static void main(String[] args) throws XgboostError { // load file from text file, also binary buffer generated by xgboost4j DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test"); diff --git a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/PredictFirstNtree.java b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/PredictFirstNtree.java index 8e3f3abfb..6bcf67f86 100644 --- a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/PredictFirstNtree.java +++ b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/PredictFirstNtree.java @@ -25,13 +25,14 @@ import org.dmlc.xgboost4j.util.Trainer; import org.dmlc.xgboost4j.demo.util.CustomEval; import org.dmlc.xgboost4j.demo.util.Params; +import org.dmlc.xgboost4j.util.XgboostError; /** * predict first ntree * @author hzx */ public class PredictFirstNtree { - public static void main(String[] args) { + public static void main(String[] args) throws XgboostError { // load file from text file, also binary buffer generated by xgboost4j DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test"); diff --git a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/PredictLeafIndices.java b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/PredictLeafIndices.java index 697f40379..61026a6b8 100644 --- a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/PredictLeafIndices.java +++ b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/PredictLeafIndices.java @@ -24,13 +24,14 @@ import org.dmlc.xgboost4j.Booster; import org.dmlc.xgboost4j.DMatrix; import org.dmlc.xgboost4j.util.Trainer; import org.dmlc.xgboost4j.demo.util.Params; +import org.dmlc.xgboost4j.util.XgboostError; /** * predict leaf indices * @author hzx */ public class PredictLeafIndices { - public static void main(String[] args) { + public static void main(String[] args) throws XgboostError { // load file from text file, also binary buffer generated by xgboost4j DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test"); diff --git a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/util/CustomEval.java b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/util/CustomEval.java index ad3a9124b..116c06ddf 100644 --- a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/util/CustomEval.java +++ b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/util/CustomEval.java @@ -15,14 +15,18 @@ */ package org.dmlc.xgboost4j.demo.util; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; import org.dmlc.xgboost4j.DMatrix; import org.dmlc.xgboost4j.IEvaluation; +import org.dmlc.xgboost4j.util.XgboostError; /** * a util evaluation class for examples * @author hzx */ public class CustomEval implements IEvaluation { + private static final Log logger = LogFactory.getLog(CustomEval.class); String evalMetric = "custom_error"; @@ -34,7 +38,13 @@ public class CustomEval implements IEvaluation { @Override public float eval(float[][] predicts, DMatrix dmat) { float error = 0f; - float[] labels = dmat.getLabel(); + float[] labels; + try { + labels = dmat.getLabel(); + } catch (XgboostError ex) { + logger.error(ex); + return -1f; + } int nrow = predicts.length; for(int i=0; i0.5) { diff --git a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/util/DataLoader.java b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/util/DataLoader.java index 0a020c761..9bad8b372 100644 --- a/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/util/DataLoader.java +++ b/java/xgboost4j-demo/src/main/java/org/dmlc/xgboost4j/demo/util/DataLoader.java @@ -77,10 +77,8 @@ public class DataLoader { reader.close(); in.close(); - Float[] flabels = (Float[]) tlabels.toArray(); - denseData.labels = ArrayUtils.toPrimitive(flabels); - Float[] fdata = (Float[]) tdata.toArray(); - denseData.data = ArrayUtils.toPrimitive(fdata); + denseData.labels = ArrayUtils.toPrimitive(tlabels.toArray(new Float[tlabels.size()])); + denseData.data = ArrayUtils.toPrimitive(tdata.toArray(new Float[tdata.size()])); return denseData; } diff --git a/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/Booster.java b/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/Booster.java index c5d8b1006..0f296241b 100644 --- a/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/Booster.java +++ b/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/Booster.java @@ -30,6 +30,8 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.dmlc.xgboost4j.util.Initializer; +import org.dmlc.xgboost4j.util.ErrorHandle; +import org.dmlc.xgboost4j.util.XgboostError; import org.dmlc.xgboost4j.wrapper.XgboostJNI; @@ -57,8 +59,9 @@ public final class Booster { * init Booster from dMatrixs * @param params parameters * @param dMatrixs DMatrix array + * @throws org.dmlc.xgboost4j.util.XgboostError */ - public Booster(Iterable> params, DMatrix[] dMatrixs) { + public Booster(Iterable> params, DMatrix[] dMatrixs) throws XgboostError { init(dMatrixs); setParam("seed","0"); setParams(params); @@ -70,9 +73,11 @@ public final class Booster { * load model from modelPath * @param params parameters * @param modelPath booster modelPath (model generated by booster.saveModel) + * @throws org.dmlc.xgboost4j.util.XgboostError */ - public Booster(Iterable> params, String modelPath) { - handle = XgboostJNI.XGBoosterCreate(new long[] {}); + public Booster(Iterable> params, String modelPath) throws XgboostError { + long[] out = new long[1]; + init(null); loadModel(modelPath); setParam("seed","0"); setParams(params); @@ -81,28 +86,33 @@ public final class Booster { - private void init(DMatrix[] dMatrixs) { + private void init(DMatrix[] dMatrixs) throws XgboostError { long[] handles = null; if(dMatrixs != null) { handles = dMatrixs2handles(dMatrixs); } - handle = XgboostJNI.XGBoosterCreate(handles); + long[] out = new long[1]; + ErrorHandle.checkCall(XgboostJNI.XGBoosterCreate(handles, out)); + + handle = out[0]; } /** * set parameter * @param key param name * @param value param value + * @throws org.dmlc.xgboost4j.util.XgboostError */ - public final void setParam(String key, String value) { - XgboostJNI.XGBoosterSetParam(handle, key, value); + public final void setParam(String key, String value) throws XgboostError { + ErrorHandle.checkCall(XgboostJNI.XGBoosterSetParam(handle, key, value)); } /** * set parameters * @param params parameters key-value map + * @throws org.dmlc.xgboost4j.util.XgboostError */ - public void setParams(Iterable> params) { + public void setParams(Iterable> params) throws XgboostError { if(params!=null) { for(Map.Entry entry : params) { setParam(entry.getKey(), entry.getValue().toString()); @@ -115,9 +125,10 @@ public final class Booster { * Update (one iteration) * @param dtrain training data * @param iter current iteration number + * @throws org.dmlc.xgboost4j.util.XgboostError */ - public void update(DMatrix dtrain, int iter) { - XgboostJNI.XGBoosterUpdateOneIter(handle, iter, dtrain.getHandle()); + public void update(DMatrix dtrain, int iter) throws XgboostError { + ErrorHandle.checkCall(XgboostJNI.XGBoosterUpdateOneIter(handle, iter, dtrain.getHandle())); } /** @@ -125,8 +136,9 @@ public final class Booster { * @param dtrain training data * @param iter current iteration number * @param obj customized objective class + * @throws org.dmlc.xgboost4j.util.XgboostError */ - public void update(DMatrix dtrain, int iter, IObjective obj) { + public void update(DMatrix dtrain, int iter, IObjective obj) throws XgboostError { float[][] predicts = predict(dtrain, true); List gradients = obj.getGradient(predicts, dtrain); boost(dtrain, gradients.get(0), gradients.get(1)); @@ -137,12 +149,13 @@ public final class Booster { * @param dtrain training data * @param grad first order of gradient * @param hess seconde order of gradient + * @throws org.dmlc.xgboost4j.util.XgboostError */ - public void boost(DMatrix dtrain, float[] grad, float[] hess) { + public void boost(DMatrix dtrain, float[] grad, float[] hess) throws XgboostError { if(grad.length != hess.length) { throw new AssertionError(String.format("grad/hess length mismatch %s / %s", grad.length, hess.length)); } - XgboostJNI.XGBoosterBoostOneIter(handle, dtrain.getHandle(), grad, hess); + ErrorHandle.checkCall(XgboostJNI.XGBoosterBoostOneIter(handle, dtrain.getHandle(), grad, hess)); } /** @@ -151,11 +164,13 @@ public final class Booster { * @param evalNames name for eval dmatrixs, used for check results * @param iter current eval iteration * @return eval information + * @throws org.dmlc.xgboost4j.util.XgboostError */ - public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, int iter) { + public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, int iter) throws XgboostError { long[] handles = dMatrixs2handles(evalMatrixs); - String evalInfo = XgboostJNI.XGBoosterEvalOneIter(handle, iter, handles, evalNames); - return evalInfo; + String[] evalInfo = new String[1]; + ErrorHandle.checkCall(XgboostJNI.XGBoosterEvalOneIter(handle, iter, handles, evalNames, evalInfo)); + return evalInfo[0]; } /** @@ -165,8 +180,9 @@ public final class Booster { * @param iter * @param eval * @return eval information + * @throws org.dmlc.xgboost4j.util.XgboostError */ - public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, int iter, IEvaluation eval) { + public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, int iter, IEvaluation eval) throws XgboostError { String evalInfo = ""; for(int i=0; i getFeatureScore() { + public Map getFeatureScore() throws XgboostError { String[] modelInfos = getDumpInfo(false); Map featureScore = new HashMap<>(); for(String tree : modelInfos) { @@ -400,8 +431,9 @@ public final class Booster { * get importance of each feature * @param featureMap file to save dumped model info * @return featureMap key: feature index, value: feature importance score + * @throws org.dmlc.xgboost4j.util.XgboostError */ - public Map getFeatureScore(String featureMap) { + public Map getFeatureScore(String featureMap) throws XgboostError { String[] modelInfos = getDumpInfo(featureMap, false); Map featureScore = new HashMap<>(); for(String tree : modelInfos) { diff --git a/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/DMatrix.java b/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/DMatrix.java index ebeb80a46..b056cad09 100644 --- a/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/DMatrix.java +++ b/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/DMatrix.java @@ -18,6 +18,8 @@ package org.dmlc.xgboost4j; import java.io.IOException; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.dmlc.xgboost4j.util.ErrorHandle; +import org.dmlc.xgboost4j.util.XgboostError; import org.dmlc.xgboost4j.util.Initializer; import org.dmlc.xgboost4j.wrapper.XgboostJNI; @@ -50,9 +52,12 @@ public class DMatrix { /** * init DMatrix from file (svmlight format) * @param dataPath + * @throws org.dmlc.xgboost4j.util.XgboostError */ - public DMatrix(String dataPath) { - handle = XgboostJNI.XGDMatrixCreateFromFile(dataPath, 1); + public DMatrix(String dataPath) throws XgboostError { + long[] out = new long[1]; + ErrorHandle.checkCall(XgboostJNI.XGDMatrixCreateFromFile(dataPath, 1, out)); + handle = out[0]; } /** @@ -61,17 +66,20 @@ public class DMatrix { * @param indices Indices (colIndexs for CSR or rowIndexs for CSC) * @param data non zero values (sequence by row for CSR or by col for CSC) * @param st sparse matrix type (CSR or CSC) + * @throws org.dmlc.xgboost4j.util.XgboostError */ - public DMatrix(long[] headers, int[] indices, float[] data, SparseType st) { + public DMatrix(long[] headers, int[] indices, float[] data, SparseType st) throws XgboostError { + long[] out = new long[1]; if(st == SparseType.CSR) { - handle = XgboostJNI.XGDMatrixCreateFromCSR(headers, indices, data); + ErrorHandle.checkCall(XgboostJNI.XGDMatrixCreateFromCSR(headers, indices, data, out)); } else if(st == SparseType.CSC) { - handle = XgboostJNI.XGDMatrixCreateFromCSC(headers, indices, data); + ErrorHandle.checkCall(XgboostJNI.XGDMatrixCreateFromCSC(headers, indices, data, out)); } else { throw new UnknownError("unknow sparsetype"); } + handle = out[0]; } /** @@ -79,10 +87,13 @@ public class DMatrix { * @param data data values * @param nrow number of rows * @param ncol number of columns + * @throws org.dmlc.xgboost4j.util.XgboostError */ - public DMatrix(float[] data, int nrow, int ncol) { - handle = XgboostJNI.XGDMatrixCreateFromMat(data, nrow, ncol, 0.0f); - } + public DMatrix(float[] data, int nrow, int ncol) throws XgboostError { + long[] out = new long[1]; + ErrorHandle.checkCall(XgboostJNI.XGDMatrixCreateFromMat(data, nrow, ncol, 0.0f, out)); + handle = out[0]; + } /** * used for DMatrix slice @@ -98,33 +109,36 @@ public class DMatrix { * set label of dmatrix * @param labels */ - public void setLabel(float[] labels) { - XgboostJNI.XGDMatrixSetFloatInfo(handle, "label", labels); + public void setLabel(float[] labels) throws XgboostError { + ErrorHandle.checkCall(XgboostJNI.XGDMatrixSetFloatInfo(handle, "label", labels)); } /** * set weight of each instance * @param weights + * @throws org.dmlc.xgboost4j.util.XgboostError */ - public void setWeight(float[] weights) { - XgboostJNI.XGDMatrixSetFloatInfo(handle, "weight", weights); + public void setWeight(float[] weights) throws XgboostError { + ErrorHandle.checkCall(XgboostJNI.XGDMatrixSetFloatInfo(handle, "weight", weights)); } /** * if specified, xgboost will start from this init margin * can be used to specify initial prediction to boost from * @param baseMargin + * @throws org.dmlc.xgboost4j.util.XgboostError */ - public void setBaseMargin(float[] baseMargin) { - XgboostJNI.XGDMatrixSetFloatInfo(handle, "base_margin", baseMargin); + public void setBaseMargin(float[] baseMargin) throws XgboostError { + ErrorHandle.checkCall(XgboostJNI.XGDMatrixSetFloatInfo(handle, "base_margin", baseMargin)); } /** * if specified, xgboost will start from this init margin * can be used to specify initial prediction to boost from * @param baseMargin + * @throws org.dmlc.xgboost4j.util.XgboostError */ - public void setBaseMargin(float[][] baseMargin) { + public void setBaseMargin(float[][] baseMargin) throws XgboostError { float[] flattenMargin = flatten(baseMargin); setBaseMargin(flattenMargin); } @@ -132,42 +146,48 @@ public class DMatrix { /** * Set group sizes of DMatrix (used for ranking) * @param group + * @throws org.dmlc.xgboost4j.util.XgboostError */ - public void setGroup(int[] group) { - XgboostJNI.XGDMatrixSetGroup(handle, group); + public void setGroup(int[] group) throws XgboostError { + ErrorHandle.checkCall(XgboostJNI.XGDMatrixSetGroup(handle, group)); } - private float[] getFloatInfo(String field) { - float[] infos = XgboostJNI.XGDMatrixGetFloatInfo(handle, field); - return infos; + private float[] getFloatInfo(String field) throws XgboostError { + float[][] infos = new float[1][]; + ErrorHandle.checkCall(XgboostJNI.XGDMatrixGetFloatInfo(handle, field, infos)); + return infos[0]; } - private int[] getIntInfo(String field) { - int[] infos = XgboostJNI.XGDMatrixGetUIntInfo(handle, field); - return infos; + private int[] getIntInfo(String field) throws XgboostError { + int[][] infos = new int[1][]; + ErrorHandle.checkCall(XgboostJNI.XGDMatrixGetUIntInfo(handle, field, infos)); + return infos[0]; } /** * get label values * @return label + * @throws org.dmlc.xgboost4j.util.XgboostError */ - public float[] getLabel() { + public float[] getLabel() throws XgboostError { return getFloatInfo("label"); } /** * get weight of the DMatrix * @return weights + * @throws org.dmlc.xgboost4j.util.XgboostError */ - public float[] getWeight() { + public float[] getWeight() throws XgboostError { return getFloatInfo("weight"); } /** * get base margin of the DMatrix * @return base margin + * @throws org.dmlc.xgboost4j.util.XgboostError */ - public float[] getBaseMargin() { + public float[] getBaseMargin() throws XgboostError { return getFloatInfo("base_margin"); } @@ -175,9 +195,12 @@ public class DMatrix { * Slice the DMatrix and return a new DMatrix that only contains `rowIndex`. * @param rowIndex * @return sliced new DMatrix + * @throws org.dmlc.xgboost4j.util.XgboostError */ - public DMatrix slice(int[] rowIndex) { - long sHandle = XgboostJNI.XGDMatrixSliceDMatrix(handle, rowIndex); + public DMatrix slice(int[] rowIndex) throws XgboostError { + long[] out = new long[1]; + ErrorHandle.checkCall(XgboostJNI.XGDMatrixSliceDMatrix(handle, rowIndex, out)); + long sHandle = out[0]; DMatrix sMatrix = new DMatrix(sHandle); return sMatrix; } @@ -185,9 +208,12 @@ public class DMatrix { /** * get the row number of DMatrix * @return number of rows + * @throws org.dmlc.xgboost4j.util.XgboostError */ - public long rowNum() { - return XgboostJNI.XGDMatrixNumRow(handle); + public long rowNum() throws XgboostError { + long[] rowNum = new long[1]; + ErrorHandle.checkCall(XgboostJNI.XGDMatrixNumRow(handle,rowNum)); + return rowNum[0]; } /** diff --git a/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/util/CVPack.java b/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/util/CVPack.java index 3e67dc669..33be48b53 100644 --- a/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/util/CVPack.java +++ b/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/util/CVPack.java @@ -37,8 +37,9 @@ public class CVPack { * @param dtrain train data * @param dtest test data * @param params parameters + * @throws org.dmlc.xgboost4j.util.XgboostError */ - public CVPack(DMatrix dtrain, DMatrix dtest, Iterable> params) { + public CVPack(DMatrix dtrain, DMatrix dtest, Iterable> params) throws XgboostError { dmats = new DMatrix[] {dtrain, dtest}; booster = new Booster(params, dmats); names = new String[] {"train", "test"}; @@ -49,8 +50,9 @@ public class CVPack { /** * update one iteration * @param iter iteration num + * @throws org.dmlc.xgboost4j.util.XgboostError */ - public void update(int iter) { + public void update(int iter) throws XgboostError { booster.update(dtrain, iter); } @@ -58,8 +60,9 @@ public class CVPack { * update one iteration * @param iter iteration num * @param obj customized objective + * @throws org.dmlc.xgboost4j.util.XgboostError */ - public void update(int iter, IObjective obj) { + public void update(int iter, IObjective obj) throws XgboostError { booster.update(dtrain, iter, obj); } @@ -67,8 +70,9 @@ public class CVPack { * evaluation * @param iter iteration num * @return + * @throws org.dmlc.xgboost4j.util.XgboostError */ - public String eval(int iter) { + public String eval(int iter) throws XgboostError { return booster.evalSet(dmats, names, iter); } @@ -77,8 +81,9 @@ public class CVPack { * @param iter iteration num * @param eval customized eval * @return + * @throws org.dmlc.xgboost4j.util.XgboostError */ - public String eval(int iter, IEvaluation eval) { + public String eval(int iter, IEvaluation eval) throws XgboostError { return booster.evalSet(dmats, names, iter, eval); } } diff --git a/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/util/ErrorHandle.java b/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/util/ErrorHandle.java new file mode 100644 index 000000000..5093eb1db --- /dev/null +++ b/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/util/ErrorHandle.java @@ -0,0 +1,50 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ +package org.dmlc.xgboost4j.util; + +import java.io.IOException; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.dmlc.xgboost4j.wrapper.XgboostJNI; + +/** + * error handle for Xgboost + * @author hzx + */ +public class ErrorHandle { + private static final Log logger = LogFactory.getLog(ErrorHandle.class); + + //load native library + static { + try { + Initializer.InitXgboost(); + } catch (IOException ex) { + logger.error("load native library failed."); + logger.error(ex); + } + } + + /** + * check the return value of C API + * @param ret return valud of xgboostJNI C API call + * @throws org.dmlc.xgboost4j.util.XgboostError + */ + public static void checkCall(int ret) throws XgboostError { + if(ret != 0) { + throw new XgboostError(XgboostJNI.XGBGetLastError()); + } + } +} diff --git a/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/util/Trainer.java b/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/util/Trainer.java index 8a336b1a8..a81963da7 100644 --- a/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/util/Trainer.java +++ b/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/util/Trainer.java @@ -47,7 +47,7 @@ public class Trainer { * @return trained booster */ public static Booster train(Iterable> params, DMatrix dtrain, int round, - Iterable> watchs, IObjective obj, IEvaluation eval) { + Iterable> watchs, IObjective obj, IEvaluation eval) throws XgboostError { //collect eval matrixs String[] evalNames; @@ -112,7 +112,7 @@ public class Trainer { * @param eval customized evaluation (set to null if not used) * @return evaluation history */ - public static String[] crossValiation(Iterable> params, DMatrix data, int round, int nfold, String[] metrics, IObjective obj, IEvaluation eval) { + public static String[] crossValiation(Iterable> params, DMatrix data, int round, int nfold, String[] metrics, IObjective obj, IEvaluation eval) throws XgboostError { CVPack[] cvPacks = makeNFold(data, nfold, params, metrics); String[] evalHist = new String[round]; String[] results = new String[cvPacks.length]; @@ -149,7 +149,7 @@ public class Trainer { * @param evalMetrics Evaluation metrics * @return CV package array */ - public static CVPack[] makeNFold(DMatrix data, int nfold, Iterable> params, String[] evalMetrics) { + public static CVPack[] makeNFold(DMatrix data, int nfold, Iterable> params, String[] evalMetrics) throws XgboostError { List samples = genRandPermutationNums(0, (int) data.rowNum()); int step = samples.size()/nfold; int[] testSlice = new int[step]; diff --git a/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/util/XgboostError.java b/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/util/XgboostError.java new file mode 100644 index 000000000..8dabcee4b --- /dev/null +++ b/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/util/XgboostError.java @@ -0,0 +1,26 @@ +/* + Copyright (c) 2014 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ +package org.dmlc.xgboost4j.util; + +/** + * custom error class for xgboost + * @author hzx + */ +public class XgboostError extends Exception{ + public XgboostError(String message) { + super(message); + } +} diff --git a/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/wrapper/XgboostJNI.java b/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/wrapper/XgboostJNI.java index 96a429c07..fe181347a 100644 --- a/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/wrapper/XgboostJNI.java +++ b/java/xgboost4j/src/main/java/org/dmlc/xgboost4j/wrapper/XgboostJNI.java @@ -17,32 +17,34 @@ package org.dmlc.xgboost4j.wrapper; /** * xgboost jni wrapper functions for xgboost_wrapper.h + * change 2015-7-6: *use a long[] (length=1) as container of handle to get the output DMatrix or Booster * @author hzx */ public class XgboostJNI { - public final static native long XGDMatrixCreateFromFile(String fname, int silent); - public final static native long XGDMatrixCreateFromCSR(long[] indptr, int[] indices, float[] data); - public final static native long XGDMatrixCreateFromCSC(long[] colptr, int[] indices, float[] data); - public final static native long XGDMatrixCreateFromMat(float[] data, int nrow, int ncol, float missing); - public final static native long XGDMatrixSliceDMatrix(long handle, int[] idxset); - public final static native void XGDMatrixFree(long handle); - public final static native void XGDMatrixSaveBinary(long handle, String fname, int silent); - public final static native void XGDMatrixSetFloatInfo(long handle, String field, float[] array); - public final static native void XGDMatrixSetUIntInfo(long handle, String field, int[] array); - public final static native void XGDMatrixSetGroup(long handle, int[] group); - public final static native float[] XGDMatrixGetFloatInfo(long handle, String field); - public final static native int[] XGDMatrixGetUIntInfo(long handle, String filed); - public final static native long XGDMatrixNumRow(long handle); - public final static native long XGBoosterCreate(long[] handles); - public final static native void XGBoosterFree(long handle); - public final static native void XGBoosterSetParam(long handle, String name, String value); - public final static native void XGBoosterUpdateOneIter(long handle, int iter, long dtrain); - public final static native void XGBoosterBoostOneIter(long handle, long dtrain, float[] grad, float[] hess); - public final static native String XGBoosterEvalOneIter(long handle, int iter, long[] dmats, String[] evnames); - public final static native float[] XGBoosterPredict(long handle, long dmat, int option_mask, long ntree_limit); - public final static native void XGBoosterLoadModel(long handle, String fname); - public final static native void XGBoosterSaveModel(long handle, String fname); - public final static native void XGBoosterLoadModelFromBuffer(long handle, long buf, long len); - public final static native String XGBoosterGetModelRaw(long handle); - public final static native String[] XGBoosterDumpModel(long handle, String fmap, int with_stats); + public final static native String XGBGetLastError(); + public final static native int XGDMatrixCreateFromFile(String fname, int silent, long[] out); + public final static native int XGDMatrixCreateFromCSR(long[] indptr, int[] indices, float[] data, long[] out); + public final static native int XGDMatrixCreateFromCSC(long[] colptr, int[] indices, float[] data, long[] out); + public final static native int XGDMatrixCreateFromMat(float[] data, int nrow, int ncol, float missing, long[] out); + public final static native int XGDMatrixSliceDMatrix(long handle, int[] idxset, long[] out); + public final static native int XGDMatrixFree(long handle); + public final static native int XGDMatrixSaveBinary(long handle, String fname, int silent); + public final static native int XGDMatrixSetFloatInfo(long handle, String field, float[] array); + public final static native int XGDMatrixSetUIntInfo(long handle, String field, int[] array); + public final static native int XGDMatrixSetGroup(long handle, int[] group); + public final static native int XGDMatrixGetFloatInfo(long handle, String field, float[][] info); + public final static native int XGDMatrixGetUIntInfo(long handle, String filed, int[][] info); + public final static native int XGDMatrixNumRow(long handle, long[] row); + public final static native int XGBoosterCreate(long[] handles, long[] out); + public final static native int XGBoosterFree(long handle); + public final static native int XGBoosterSetParam(long handle, String name, String value); + public final static native int XGBoosterUpdateOneIter(long handle, int iter, long dtrain); + public final static native int XGBoosterBoostOneIter(long handle, long dtrain, float[] grad, float[] hess); + public final static native int XGBoosterEvalOneIter(long handle, int iter, long[] dmats, String[] evnames, String[] eval_info); + public final static native int XGBoosterPredict(long handle, long dmat, int option_mask, long ntree_limit, float[][] predicts); + public final static native int XGBoosterLoadModel(long handle, String fname); + public final static native int XGBoosterSaveModel(long handle, String fname); + public final static native int XGBoosterLoadModelFromBuffer(long handle, long buf, long len); + public final static native int XGBoosterGetModelRaw(long handle, String[] out_string); + public final static native int XGBoosterDumpModel(long handle, String fmap, int with_stats, String[][] out_strings); } diff --git a/java/xgboost4j_wrapper.cpp b/java/xgboost4j_wrapper.cpp index 55dc31bc8..f1e749982 100644 --- a/java/xgboost4j_wrapper.cpp +++ b/java/xgboost4j_wrapper.cpp @@ -16,21 +16,34 @@ #include "../wrapper/xgboost_wrapper.h" #include "xgboost4j_wrapper.h" -JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromFile - (JNIEnv *jenv, jclass jcls, jstring jfname, jint jsilent) { - jlong jresult = 0 ; +JNIEXPORT jstring JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBGetLastError + (JNIEnv *jenv, jclass jcls) { + jstring jresult = 0 ; + char* result = 0; + result = (char *)XGBGetLastError(); + if (result) jresult = jenv->NewStringUTF((const char *)result); + return jresult; +} + +JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromFile + (JNIEnv *jenv, jclass jcls, jstring jfname, jint jsilent, jlongArray jout) { + jint jresult = 0 ; char *fname = (char *) 0 ; int silent; - void *result = 0 ; - fname = 0; - if (jfname) { - fname = (char *)jenv->GetStringUTFChars(jfname, 0); - if (!fname) return 0; - } + void* result[1]; + unsigned long out[1]; + + fname = (char *)jenv->GetStringUTFChars(jfname, 0); + silent = (int)jsilent; - result = (void *)XGDMatrixCreateFromFile((char const *)fname, silent); - *(void **)&jresult = result; + jresult = (jint) XGDMatrixCreateFromFile((char const *)fname, silent, result); + + + *(void **)&out[0] = *result; + if (fname) jenv->ReleaseStringUTFChars(jfname, (const char *)fname); + + jenv->SetLongArrayRegion(jout, 0, 1, (const jlong *) out); return jresult; } @@ -39,12 +52,13 @@ JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCrea * Method: XGDMatrixCreateFromCSR * Signature: ([J[J[F)J */ -JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromCSR - (JNIEnv *jenv, jclass jcls, jlongArray jindptr, jintArray jindices, jfloatArray jdata) { - jlong jresult = 0 ; +JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromCSR + (JNIEnv *jenv, jclass jcls, jlongArray jindptr, jintArray jindices, jfloatArray jdata, jlongArray jout) { + jint jresult = 0 ; bst_ulong nindptr ; bst_ulong nelem; - void *result = 0 ; + void *result[1]; + unsigned long out[1]; jlong* indptr = jenv->GetLongArrayElements(jindptr, 0); jint* indices = jenv->GetIntArrayElements(jindices, 0); @@ -52,8 +66,9 @@ JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCrea nindptr = (bst_ulong)jenv->GetArrayLength(jindptr); nelem = (bst_ulong)jenv->GetArrayLength(jdata); - result = (void *)XGDMatrixCreateFromCSR((unsigned long const *)indptr, (unsigned int const *)indices, (float const *)data, nindptr, nelem); - *(void **)&jresult = result; + jresult = (jint) XGDMatrixCreateFromCSR((unsigned long const *)indptr, (unsigned int const *)indices, (float const *)data, nindptr, nelem, result); + *(void **)&out[0] = *result; + jenv->SetLongArrayRegion(jout, 0, 1, (const jlong *) out); //release jenv->ReleaseLongArrayElements(jindptr, indptr, 0); @@ -68,12 +83,13 @@ JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCrea * Method: XGDMatrixCreateFromCSC * Signature: ([J[J[F)J */ -JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromCSC - (JNIEnv *jenv, jclass jcls, jlongArray jindptr, jintArray jindices, jfloatArray jdata) { - jlong jresult = 0 ; +JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromCSC + (JNIEnv *jenv, jclass jcls, jlongArray jindptr, jintArray jindices, jfloatArray jdata, jlongArray jout) { + jint jresult = 0; bst_ulong nindptr ; bst_ulong nelem; - void *result = 0 ; + void *result[1]; + unsigned long out[1]; jlong* indptr = jenv->GetLongArrayElements(jindptr, NULL); jint* indices = jenv->GetIntArrayElements(jindices, 0); @@ -81,8 +97,9 @@ JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCrea nindptr = (bst_ulong)jenv->GetArrayLength(jindptr); nelem = (bst_ulong)jenv->GetArrayLength(jdata); - result = (void *)XGDMatrixCreateFromCSC((unsigned long const *)indptr, (unsigned int const *)indices, (float const *)data, nindptr, nelem); - *(void **)&jresult = result; + jresult = (jint) XGDMatrixCreateFromCSC((unsigned long const *)indptr, (unsigned int const *)indices, (float const *)data, nindptr, nelem, result); + *(void **)&out[0] = *result; + jenv->SetLongArrayRegion(jout, 0, 1, (const jlong *) out); //release jenv->ReleaseLongArrayElements(jindptr, indptr, 0); @@ -97,21 +114,24 @@ JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCrea * Method: XGDMatrixCreateFromMat * Signature: ([FIIF)J */ -JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromMat - (JNIEnv *jenv, jclass jcls, jfloatArray jdata, jint jnrow, jint jncol, jfloat jmiss) { - jlong jresult = 0 ; +JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromMat + (JNIEnv *jenv, jclass jcls, jfloatArray jdata, jint jnrow, jint jncol, jfloat jmiss, jlongArray jout) { + jint jresult = 0 ; bst_ulong nrow ; bst_ulong ncol ; float miss ; - void *result = 0 ; + void *result[1]; + unsigned long out[1]; jfloat* data = jenv->GetFloatArrayElements(jdata, 0); nrow = (bst_ulong)jnrow; ncol = (bst_ulong)jncol; miss = (float)jmiss; - result = (void *)XGDMatrixCreateFromMat((float const *)data, nrow, ncol, miss); - *(void **)&jresult = result; + + jresult = (jint) XGDMatrixCreateFromMat((float const *)data, nrow, ncol, miss, result); + *(void **)&out[0] = *result; + jenv->SetLongArrayRegion(jout, 0, 1, (const jlong *) out); //release jenv->ReleaseFloatArrayElements(jdata, data, 0); @@ -124,19 +144,21 @@ JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCrea * Method: XGDMatrixSliceDMatrix * Signature: (J[I)J */ -JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSliceDMatrix - (JNIEnv *jenv, jclass jcls, jlong jhandle, jintArray jindexset) { - jlong jresult = 0 ; +JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSliceDMatrix + (JNIEnv *jenv, jclass jcls, jlong jhandle, jintArray jindexset, jlongArray jout) { + jint jresult = 0 ; void *handle = (void *) 0 ; bst_ulong len; - void *result = 0 ; + void *result[1]; + unsigned long out[1]; jint* indexset = jenv->GetIntArrayElements(jindexset, 0); handle = *(void **)&jhandle; len = (bst_ulong)jenv->GetArrayLength(jindexset); - result = (void *)XGDMatrixSliceDMatrix(handle, (int const *)indexset, len); - *(void **)&jresult = result; + jresult = (jint) XGDMatrixSliceDMatrix(handle, (int const *)indexset, len, result); + *(void **)&out[0] = *result; + jenv->SetLongArrayRegion(jout, 0, 1, (const jlong *) out); //release jenv->ReleaseIntArrayElements(jindexset, indexset, 0); @@ -149,11 +171,13 @@ JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSlic * Method: XGDMatrixFree * Signature: (J)V */ -JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixFree +JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixFree (JNIEnv *jenv, jclass jcls, jlong jhandle) { + jint jresult = 0; void *handle = (void *) 0 ; handle = *(void **)&jhandle; - XGDMatrixFree(handle); + jresult = (jint) XGDMatrixFree(handle); + return jresult; } /* @@ -161,20 +185,21 @@ JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixFree * Method: XGDMatrixSaveBinary * Signature: (JLjava/lang/String;I)V */ -JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSaveBinary +JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSaveBinary (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfname, jint jsilent) { + jint jresult = 0; void *handle = (void *) 0 ; char *fname = (char *) 0 ; int silent ; handle = *(void **)&jhandle; fname = 0; - if (jfname) { - fname = (char *)jenv->GetStringUTFChars(jfname, 0); - if (!fname) return ; - } + fname = (char *)jenv->GetStringUTFChars(jfname, 0); + silent = (int)jsilent; - XGDMatrixSaveBinary(handle, (char const *)fname, silent); + jresult = (jint) XGDMatrixSaveBinary(handle, (char const *)fname, silent); if (fname) jenv->ReleaseStringUTFChars(jfname, (const char *)fname); + + return jresult; } /* @@ -182,27 +207,28 @@ JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSaveB * Method: XGDMatrixSetFloatInfo * Signature: (JLjava/lang/String;[F)V */ -JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetFloatInfo +JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetFloatInfo (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jfloatArray jarray) { + jint jresult = 0; void *handle = (void *) 0 ; char *field = (char *) 0 ; bst_ulong len; handle = *(void **)&jhandle; - field = 0; - if (jfield) { - field = (char *)jenv->GetStringUTFChars(jfield, 0); - if (!field) return ; - } + + field = (char *)jenv->GetStringUTFChars(jfield, 0); + jfloat* array = jenv->GetFloatArrayElements(jarray, NULL); len = (bst_ulong)jenv->GetArrayLength(jarray); - XGDMatrixSetFloatInfo(handle, (char const *)field, (float const *)array, len); + jresult = (jint) XGDMatrixSetFloatInfo(handle, (char const *)field, (float const *)array, len); //release if (field) jenv->ReleaseStringUTFChars(jfield, (const char *)field); jenv->ReleaseFloatArrayElements(jarray, array, 0); + + return jresult; } /* @@ -210,25 +236,26 @@ JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetFl * Method: XGDMatrixSetUIntInfo * Signature: (JLjava/lang/String;[I)V */ -JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetUIntInfo +JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetUIntInfo (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jintArray jarray) { + jint jresult = 0; void *handle = (void *) 0 ; char *field = (char *) 0 ; bst_ulong len ; handle = *(void **)&jhandle; field = 0; - if (jfield) { - field = (char *)jenv->GetStringUTFChars(jfield, 0); - if (!field) return ; - } + field = (char *)jenv->GetStringUTFChars(jfield, 0); + jint* array = jenv->GetIntArrayElements(jarray, NULL); len = (bst_ulong)jenv->GetArrayLength(jarray); - XGDMatrixSetUIntInfo(handle, (char const *)field, (unsigned int const *)array, len); + jresult = (jint) XGDMatrixSetUIntInfo(handle, (char const *)field, (unsigned int const *)array, len); //release if (field) jenv->ReleaseStringUTFChars(jfield, (const char *)field); jenv->ReleaseIntArrayElements(jarray, array, 0); + + return jresult; } /* @@ -236,8 +263,9 @@ JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetUI * Method: XGDMatrixSetGroup * Signature: (J[I)V */ -JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetGroup +JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetGroup (JNIEnv * jenv, jclass jcls, jlong jhandle, jintArray jarray) { + jint jresult = 0; void *handle = (void *) 0 ; bst_ulong len ; @@ -245,11 +273,12 @@ JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetGr jint* array = jenv->GetIntArrayElements(jarray, NULL); len = (bst_ulong)jenv->GetArrayLength(jarray); - XGDMatrixSetGroup(handle, (unsigned int const *)array, len); + jresult = (jint) XGDMatrixSetGroup(handle, (unsigned int const *)array, len); //release jenv->ReleaseIntArrayElements(jarray, array, 0); - + + return jresult; } /* @@ -257,28 +286,31 @@ JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetGr * Method: XGDMatrixGetFloatInfo * Signature: (JLjava/lang/String;)[F */ -JNIEXPORT jfloatArray JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixGetFloatInfo - (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield) { +JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixGetFloatInfo + (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jobjectArray jout) { + jint jresult = 0; void *handle = (void *) 0 ; char *field = (char *) 0 ; bst_ulong len[1]; *len = 0; - float *result = 0 ; + float *result[1]; - handle = *(void **)&jhandle; + handle = *(void **)&jhandle; field = 0; if (jfield) { field = (char *)jenv->GetStringUTFChars(jfield, 0); if (!field) return 0; } - result = (float *)XGDMatrixGetFloatInfo((void const *)handle, (char const *)field, len); + jresult = (jint) XGDMatrixGetFloatInfo(handle, (char const *)field, len, (const float **) result); if (field) jenv->ReleaseStringUTFChars(jfield, (const char *)field); jsize jlen = (jsize)*len; - jfloatArray jresult = jenv->NewFloatArray(jlen); - jenv->SetFloatArrayRegion(jresult, 0, jlen, (jfloat *)result); + jfloatArray jarray = jenv->NewFloatArray(jlen); + jenv->SetFloatArrayRegion(jarray, 0, jlen, (jfloat *) *result); + jenv->SetObjectArrayElement(jout, 0, (jobject) jarray); + return jresult; } @@ -287,28 +319,26 @@ JNIEXPORT jfloatArray JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatr * Method: XGDMatrixGetUIntInfo * Signature: (JLjava/lang/String;)[I */ -JNIEXPORT jintArray JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixGetUIntInfo - (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield) { +JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixGetUIntInfo + (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jobjectArray jout) { + jint jresult = 0; void *handle = (void *) 0 ; char *field = (char *) 0 ; bst_ulong len[1]; *len = 0; - unsigned int *result = 0 ; + unsigned int *result[1]; handle = *(void **)&jhandle; - field = 0; - if (jfield) { - field = (char *)jenv->GetStringUTFChars(jfield, 0); - if (!field) return 0; - } + field = (char *)jenv->GetStringUTFChars(jfield, 0); - result = (unsigned int *)XGDMatrixGetUIntInfo((void const *)handle, (char const *)field, len); + jresult = (jint) XGDMatrixGetUIntInfo(handle, (char const *)field, len, (const unsigned int **) result); if (field) jenv->ReleaseStringUTFChars(jfield, (const char *)field); jsize jlen = (jsize)*len; - jintArray jresult = jenv->NewIntArray(jlen); - jenv->SetIntArrayRegion(jresult, 0, jlen, (jint *)result); + jintArray jarray = jenv->NewIntArray(jlen); + jenv->SetIntArrayRegion(jarray, 0, jlen, (jint *) *result); + jenv->SetObjectArrayElement(jout, 0, jarray); return jresult; } @@ -317,14 +347,14 @@ JNIEXPORT jintArray JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrix * Method: XGDMatrixNumRow * Signature: (J)J */ -JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixNumRow - (JNIEnv *jenv, jclass jcls, jlong jhandle) { - jlong jresult = 0 ; +JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixNumRow + (JNIEnv *jenv, jclass jcls, jlong jhandle, jlongArray jout) { + jint jresult = 0 ; void *handle = (void *) 0 ; - bst_ulong result; + bst_ulong result[1]; handle = *(void **)&jhandle; - result = (bst_ulong)XGDMatrixNumRow((void const *)handle); - jresult = (jlong)result; + jresult = (jint) XGDMatrixNumRow(handle, result); + jenv->SetLongArrayRegion(jout, 0, 1, (const jlong *) result); return jresult; } @@ -333,13 +363,14 @@ JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixNumR * Method: XGBoosterCreate * Signature: ([J)J */ -JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterCreate - (JNIEnv *jenv, jclass jcls, jlongArray jhandles) { - jlong jresult = 0 ; +JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterCreate + (JNIEnv *jenv, jclass jcls, jlongArray jhandles, jlongArray jout) { + jint jresult = 0; void **handles = 0; bst_ulong len = 0; - void *result = 0 ; + void *result[1]; jlong* cjhandles = 0; + unsigned long out[1]; if(jhandles) { len = (bst_ulong)jenv->GetArrayLength(jhandles); @@ -351,7 +382,7 @@ JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterCrea } } - result = (void *)XGBoosterCreate(handles, len); + jresult = (jint) XGBoosterCreate(handles, len, result); //release if(jhandles) { @@ -359,7 +390,9 @@ JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterCrea jenv->ReleaseLongArrayElements(jhandles, cjhandles, 0); } - *(void **)&jresult = result; + *(void **)&out[0] = *result; + jenv->SetLongArrayRegion(jout, 0, 1, (const jlong *) out); + return jresult; } @@ -368,11 +401,11 @@ JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterCrea * Method: XGBoosterFree * Signature: (J)V */ -JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterFree +JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterFree (JNIEnv *jenv, jclass jcls, jlong jhandle) { void *handle = (void *) 0 ; handle = *(void **)&jhandle; - XGBoosterFree(handle); + return (jint) XGBoosterFree(handle); } @@ -381,27 +414,22 @@ JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterFree * Method: XGBoosterSetParam * Signature: (JLjava/lang/String;Ljava/lang/String;)V */ -JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterSetParam +JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterSetParam (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jname, jstring jvalue) { + jint jresult = -1; void *handle = (void *) 0 ; char *name = (char *) 0 ; char *value = (char *) 0 ; handle = *(void **)&jhandle; - name = 0; - if (jname) { - name = (char *)jenv->GetStringUTFChars(jname, 0); - if (!name) return ; - } - - value = 0; - if (jvalue) { - value = (char *)jenv->GetStringUTFChars(jvalue, 0); - if (!value) return ; - } - XGBoosterSetParam(handle, (char const *)name, (char const *)value); + name = (char *)jenv->GetStringUTFChars(jname, 0); + value = (char *)jenv->GetStringUTFChars(jvalue, 0); + + jresult = (jint) XGBoosterSetParam(handle, (char const *)name, (char const *)value); if (name) jenv->ReleaseStringUTFChars(jname, (const char *)name); if (value) jenv->ReleaseStringUTFChars(jvalue, (const char *)value); + + return jresult; } /* @@ -409,7 +437,7 @@ JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterSetPa * Method: XGBoosterUpdateOneIter * Signature: (JIJ)V */ -JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterUpdateOneIter +JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterUpdateOneIter (JNIEnv *jenv, jclass jcls, jlong jhandle, jint jiter, jlong jdtrain) { void *handle = (void *) 0 ; int iter ; @@ -417,7 +445,7 @@ JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterUpdat handle = *(void **)&jhandle; iter = (int)jiter; dtrain = *(void **)&jdtrain; - XGBoosterUpdateOneIter(handle, iter, dtrain); + return (jint) XGBoosterUpdateOneIter(handle, iter, dtrain); } /* @@ -425,8 +453,9 @@ JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterUpdat * Method: XGBoosterBoostOneIter * Signature: (JJ[F[F)V */ -JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterBoostOneIter +JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterBoostOneIter (JNIEnv *jenv, jclass jcls, jlong jhandle, jlong jdtrain, jfloatArray jgrad, jfloatArray jhess) { + jint jresult = 0; void *handle = (void *) 0 ; void *dtrain = (void *) 0 ; bst_ulong len ; @@ -436,11 +465,13 @@ JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterBoost jfloat* grad = jenv->GetFloatArrayElements(jgrad, 0); jfloat* hess = jenv->GetFloatArrayElements(jhess, 0); len = (bst_ulong)jenv->GetArrayLength(jgrad); - XGBoosterBoostOneIter(handle, dtrain, grad, hess, len); + jresult = (jint) XGBoosterBoostOneIter(handle, dtrain, grad, hess, len); //release jenv->ReleaseFloatArrayElements(jgrad, grad, 0); jenv->ReleaseFloatArrayElements(jhess, hess, 0); + + return jresult; } /* @@ -448,15 +479,15 @@ JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterBoost * Method: XGBoosterEvalOneIter * Signature: (JI[J[Ljava/lang/String;)Ljava/lang/String; */ -JNIEXPORT jstring JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterEvalOneIter - (JNIEnv *jenv, jclass jcls, jlong jhandle, jint jiter, jlongArray jdmats, jobjectArray jevnames) { - jstring jresult = 0 ; +JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterEvalOneIter + (JNIEnv *jenv, jclass jcls, jlong jhandle, jint jiter, jlongArray jdmats, jobjectArray jevnames, jobjectArray jout) { + jint jresult = 0 ; void *handle = (void *) 0 ; int iter ; void **dmats = 0; char **evnames = 0; bst_ulong len ; - char *result = 0 ; + char *result[1]; handle = *(void **)&jhandle; iter = (int)jiter; @@ -480,7 +511,7 @@ JNIEXPORT jstring JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterEv evnames[i] = (char *)jenv->GetStringUTFChars(jevname, 0); } - result = (char *)XGBoosterEvalOneIter(handle, iter, dmats, (char const *(*))evnames, len); + jresult = (jint) XGBoosterEvalOneIter(handle, iter, dmats, (char const *(*))evnames, len, (const char **) result); if(len > 0) { delete[] dmats; @@ -493,7 +524,9 @@ JNIEXPORT jstring JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterEv jenv->ReleaseLongArrayElements(jdmats, cjdmats, 0); } - if (result) jresult = jenv->NewStringUTF((const char *)result); + jstring jinfo = 0; + if (*result) jinfo = jenv->NewStringUTF((const char *) *result); + jenv->SetObjectArrayElement(jout, 0, jinfo); return jresult; } @@ -503,26 +536,29 @@ JNIEXPORT jstring JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterEv * Method: XGBoosterPredict * Signature: (JJIJ)[F */ -JNIEXPORT jfloatArray JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterPredict - (JNIEnv *jenv, jclass jcls, jlong jhandle, jlong jdmat, jint joption_mask, jlong jntree_limit) { +JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterPredict + (JNIEnv *jenv, jclass jcls, jlong jhandle, jlong jdmat, jint joption_mask, jlong jntree_limit, jobjectArray jout) { + jint jresult = 0; void *handle = (void *) 0 ; void *dmat = (void *) 0 ; int option_mask ; unsigned int ntree_limit ; bst_ulong len[1]; *len = 0; - float *result = 0 ; + float *result[1]; handle = *(void **)&jhandle; dmat = *(void **)&jdmat; option_mask = (int)joption_mask; ntree_limit = (unsigned int)jntree_limit; - result = (float *)XGBoosterPredict(handle, dmat, option_mask, ntree_limit, len); + jresult = (jint) XGBoosterPredict(handle, dmat, option_mask, ntree_limit, len, (const float **) result); jsize jlen = (jsize)*len; - jfloatArray jresult = jenv->NewFloatArray(jlen); - jenv->SetFloatArrayRegion(jresult, 0, jlen, (jfloat *)result); + jfloatArray jarray = jenv->NewFloatArray(jlen); + jenv->SetFloatArrayRegion(jarray, 0, jlen, (jfloat *) *result); + jenv->SetObjectArrayElement(jout, 0, jarray); + return jresult; } @@ -531,18 +567,20 @@ JNIEXPORT jfloatArray JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoost * Method: XGBoosterLoadModel * Signature: (JLjava/lang/String;)V */ -JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterLoadModel +JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterLoadModel (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfname) { + jint jresult = 0; void *handle = (void *) 0 ; char *fname = (char *) 0 ; handle = *(void **)&jhandle; - fname = 0; - if (jfname) { - fname = (char *)jenv->GetStringUTFChars(jfname, 0); - if (!fname) return ; - } - XGBoosterLoadModel(handle,(char const *)fname); + + fname = (char *)jenv->GetStringUTFChars(jfname, 0); + + + jresult = (jint) XGBoosterLoadModel(handle,(char const *)fname); if (fname) jenv->ReleaseStringUTFChars(jfname, (const char *)fname); + + return jresult; } /* @@ -550,18 +588,19 @@ JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterLoadM * Method: XGBoosterSaveModel * Signature: (JLjava/lang/String;)V */ -JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterSaveModel +JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterSaveModel (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfname) { + jint jresult = 0; void *handle = (void *) 0 ; char *fname = (char *) 0 ; handle = *(void **)&jhandle; fname = 0; - if (jfname) { - fname = (char *)jenv->GetStringUTFChars(jfname, 0); - if (!fname) return ; - } - XGBoosterSaveModel(handle, (char const *)fname); + fname = (char *)jenv->GetStringUTFChars(jfname, 0); + + jresult = (jint) XGBoosterSaveModel(handle, (char const *)fname); if (fname) jenv->ReleaseStringUTFChars(jfname, (const char *)fname); + + return jresult; } /* @@ -569,7 +608,7 @@ JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterSaveM * Method: XGBoosterLoadModelFromBuffer * Signature: (JJJ)V */ -JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterLoadModelFromBuffer +JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterLoadModelFromBuffer (JNIEnv *jenv, jclass jcls, jlong jhandle, jlong jbuf, jlong jlen) { void *handle = (void *) 0 ; void *buf = (void *) 0 ; @@ -577,7 +616,7 @@ JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterLoadM handle = *(void **)&jhandle; buf = *(void **)&jbuf; len = (bst_ulong)jlen; - XGBoosterLoadModelFromBuffer(handle, (void const *)buf, len); + return (jint) XGBoosterLoadModelFromBuffer(handle, (void const *)buf, len); } /* @@ -585,17 +624,21 @@ JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterLoadM * Method: XGBoosterGetModelRaw * Signature: (J)Ljava/lang/String; */ -JNIEXPORT jstring JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterGetModelRaw - (JNIEnv * jenv, jclass jcls, jlong jhandle) { - jstring jresult = 0 ; +JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterGetModelRaw + (JNIEnv * jenv, jclass jcls, jlong jhandle, jobjectArray jout) { + jint jresult = 0 ; + jstring jinfo = 0; void *handle = (void *) 0 ; bst_ulong len[1]; *len = 0; - char *result = 0 ; + char *result[1]; handle = *(void **)&jhandle; - result = (char *)XGBoosterGetModelRaw(handle, len); - if (result) jresult = jenv->NewStringUTF((const char *)result); + jresult = (jint)XGBoosterGetModelRaw(handle, len, (const char **) result); + if (*result){ + jinfo = jenv->NewStringUTF((const char *) *result); + jenv->SetObjectArrayElement(jout, 0, jinfo); + } return jresult; } @@ -604,15 +647,16 @@ JNIEXPORT jstring JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterGe * Method: XGBoosterDumpModel * Signature: (JLjava/lang/String;I)[Ljava/lang/String; */ -JNIEXPORT jobjectArray JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterDumpModel - (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfmap, jint jwith_stats) { +JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterDumpModel + (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfmap, jint jwith_stats, jobjectArray jout) { + jint jresult = 0; void *handle = (void *) 0 ; char *fmap = (char *) 0 ; int with_stats ; bst_ulong len[1]; *len = 0; - char **result = 0 ; + char **result[1]; handle = *(void **)&jhandle; fmap = 0; if (jfmap) { @@ -621,14 +665,16 @@ JNIEXPORT jobjectArray JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoos } with_stats = (int)jwith_stats; - result = (char **)XGBoosterDumpModel(handle, (char const *)fmap, with_stats, len); + jresult = (jint) XGBoosterDumpModel(handle, (const char *)fmap, with_stats, len, (const char ***) result); jsize jlen = (jsize)*len; - jobjectArray jresult = jenv->NewObjectArray(jlen, jenv->FindClass("java/lang/String"), jenv->NewStringUTF("")); + jobjectArray jinfos = jenv->NewObjectArray(jlen, jenv->FindClass("java/lang/String"), jenv->NewStringUTF("")); for(int i=0 ; iSetObjectArrayElement(jresult, i, jenv->NewStringUTF((const char*)result[i])); + jenv->SetObjectArrayElement(jinfos, i, jenv->NewStringUTF((const char*) result[0][i])); } + jenv->SetObjectArrayElement(jout, 0, jinfos); if (fmap) jenv->ReleaseStringUTFChars(jfmap, (const char *)fmap); + return jresult; } \ No newline at end of file diff --git a/java/xgboost4j_wrapper.h b/java/xgboost4j_wrapper.h index d13b86f8c..93764ef53 100644 --- a/java/xgboost4j_wrapper.h +++ b/java/xgboost4j_wrapper.h @@ -9,203 +9,211 @@ extern "C" { #endif /* * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI - * Method: XGDMatrixCreateFromFile - * Signature: (Ljava/lang/String;I)J + * Method: XGBGetLastError + * Signature: ()Ljava/lang/String; */ -JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromFile - (JNIEnv *, jclass, jstring, jint); +JNIEXPORT jstring JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBGetLastError + (JNIEnv *, jclass); + +/* + * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI + * Method: XGDMatrixCreateFromFile + * Signature: (Ljava/lang/String;I[J)I + */ +JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromFile + (JNIEnv *, jclass, jstring, jint, jlongArray); /* * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI * Method: XGDMatrixCreateFromCSR - * Signature: ([J[J[F)J + * Signature: ([J[I[F[J)I */ -JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromCSR - (JNIEnv *, jclass, jlongArray, jintArray, jfloatArray); +JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromCSR + (JNIEnv *, jclass, jlongArray, jintArray, jfloatArray, jlongArray); /* * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI * Method: XGDMatrixCreateFromCSC - * Signature: ([J[J[F)J + * Signature: ([J[I[F[J)I */ -JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromCSC - (JNIEnv *, jclass, jlongArray, jintArray, jfloatArray); +JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromCSC + (JNIEnv *, jclass, jlongArray, jintArray, jfloatArray, jlongArray); /* * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI * Method: XGDMatrixCreateFromMat - * Signature: ([FIIF)J + * Signature: ([FIIF[J)I */ -JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromMat - (JNIEnv *, jclass, jfloatArray, jint, jint, jfloat); +JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromMat + (JNIEnv *, jclass, jfloatArray, jint, jint, jfloat, jlongArray); /* * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI * Method: XGDMatrixSliceDMatrix - * Signature: (J[I)J + * Signature: (J[I[J)I */ -JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSliceDMatrix - (JNIEnv *, jclass, jlong, jintArray); +JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSliceDMatrix + (JNIEnv *, jclass, jlong, jintArray, jlongArray); /* * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI * Method: XGDMatrixFree - * Signature: (J)V + * Signature: (J)I */ -JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixFree +JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixFree (JNIEnv *, jclass, jlong); /* * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI * Method: XGDMatrixSaveBinary - * Signature: (JLjava/lang/String;I)V + * Signature: (JLjava/lang/String;I)I */ -JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSaveBinary +JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSaveBinary (JNIEnv *, jclass, jlong, jstring, jint); /* * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI * Method: XGDMatrixSetFloatInfo - * Signature: (JLjava/lang/String;[F)V + * Signature: (JLjava/lang/String;[F)I */ -JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetFloatInfo +JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetFloatInfo (JNIEnv *, jclass, jlong, jstring, jfloatArray); /* * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI * Method: XGDMatrixSetUIntInfo - * Signature: (JLjava/lang/String;[I)V + * Signature: (JLjava/lang/String;[I)I */ -JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetUIntInfo +JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetUIntInfo (JNIEnv *, jclass, jlong, jstring, jintArray); /* * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI * Method: XGDMatrixSetGroup - * Signature: (J[I)V + * Signature: (J[I)I */ -JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetGroup +JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetGroup (JNIEnv *, jclass, jlong, jintArray); /* * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI * Method: XGDMatrixGetFloatInfo - * Signature: (JLjava/lang/String;)[F + * Signature: (JLjava/lang/String;[[F)I */ -JNIEXPORT jfloatArray JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixGetFloatInfo - (JNIEnv *, jclass, jlong, jstring); +JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixGetFloatInfo + (JNIEnv *, jclass, jlong, jstring, jobjectArray); /* * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI * Method: XGDMatrixGetUIntInfo - * Signature: (JLjava/lang/String;)[I + * Signature: (JLjava/lang/String;[[I)I */ -JNIEXPORT jintArray JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixGetUIntInfo - (JNIEnv *, jclass, jlong, jstring); +JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixGetUIntInfo + (JNIEnv *, jclass, jlong, jstring, jobjectArray); /* * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI * Method: XGDMatrixNumRow - * Signature: (J)J + * Signature: (J[J)I */ -JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixNumRow - (JNIEnv *, jclass, jlong); +JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixNumRow + (JNIEnv *, jclass, jlong, jlongArray); /* * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI * Method: XGBoosterCreate - * Signature: ([J)J + * Signature: ([J[J)I */ -JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterCreate - (JNIEnv *, jclass, jlongArray); +JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterCreate + (JNIEnv *, jclass, jlongArray, jlongArray); /* * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI * Method: XGBoosterFree - * Signature: (J)V + * Signature: (J)I */ -JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterFree +JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterFree (JNIEnv *, jclass, jlong); /* * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI * Method: XGBoosterSetParam - * Signature: (JLjava/lang/String;Ljava/lang/String;)V + * Signature: (JLjava/lang/String;Ljava/lang/String;)I */ -JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterSetParam +JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterSetParam (JNIEnv *, jclass, jlong, jstring, jstring); /* * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI * Method: XGBoosterUpdateOneIter - * Signature: (JIJ)V + * Signature: (JIJ)I */ -JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterUpdateOneIter +JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterUpdateOneIter (JNIEnv *, jclass, jlong, jint, jlong); /* * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI * Method: XGBoosterBoostOneIter - * Signature: (JJ[F[F)V + * Signature: (JJ[F[F)I */ -JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterBoostOneIter +JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterBoostOneIter (JNIEnv *, jclass, jlong, jlong, jfloatArray, jfloatArray); /* * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI * Method: XGBoosterEvalOneIter - * Signature: (JI[J[Ljava/lang/String;)Ljava/lang/String; + * Signature: (JI[J[Ljava/lang/String;[Ljava/lang/String;)I */ -JNIEXPORT jstring JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterEvalOneIter - (JNIEnv *, jclass, jlong, jint, jlongArray, jobjectArray); +JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterEvalOneIter + (JNIEnv *, jclass, jlong, jint, jlongArray, jobjectArray, jobjectArray); /* * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI * Method: XGBoosterPredict - * Signature: (JJIJ)[F + * Signature: (JJIJ[[F)I */ -JNIEXPORT jfloatArray JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterPredict - (JNIEnv *, jclass, jlong, jlong, jint, jlong); +JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterPredict + (JNIEnv *, jclass, jlong, jlong, jint, jlong, jobjectArray); /* * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI * Method: XGBoosterLoadModel - * Signature: (JLjava/lang/String;)V + * Signature: (JLjava/lang/String;)I */ -JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterLoadModel +JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterLoadModel (JNIEnv *, jclass, jlong, jstring); /* * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI * Method: XGBoosterSaveModel - * Signature: (JLjava/lang/String;)V + * Signature: (JLjava/lang/String;)I */ -JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterSaveModel +JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterSaveModel (JNIEnv *, jclass, jlong, jstring); /* * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI * Method: XGBoosterLoadModelFromBuffer - * Signature: (JJJ)V + * Signature: (JJJ)I */ -JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterLoadModelFromBuffer +JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterLoadModelFromBuffer (JNIEnv *, jclass, jlong, jlong, jlong); /* * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI * Method: XGBoosterGetModelRaw - * Signature: (J)Ljava/lang/String; + * Signature: (J[Ljava/lang/String;)I */ -JNIEXPORT jstring JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterGetModelRaw - (JNIEnv *, jclass, jlong); +JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterGetModelRaw + (JNIEnv *, jclass, jlong, jobjectArray); /* * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI * Method: XGBoosterDumpModel - * Signature: (JLjava/lang/String;I)[Ljava/lang/String; + * Signature: (JLjava/lang/String;I[[Ljava/lang/String;)I */ -JNIEXPORT jobjectArray JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterDumpModel - (JNIEnv *, jclass, jlong, jstring, jint); +JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterDumpModel + (JNIEnv *, jclass, jlong, jstring, jint, jobjectArray); #ifdef __cplusplus }