update java wrapper for new fault handle API

This commit is contained in:
yanqingmen 2015-07-06 02:32:58 -07:00
parent 7755c00721
commit f73bcd427d
19 changed files with 558 additions and 329 deletions

View File

@ -31,6 +31,7 @@ import org.dmlc.xgboost4j.DMatrix;
import org.dmlc.xgboost4j.demo.util.DataLoader; import org.dmlc.xgboost4j.demo.util.DataLoader;
import org.dmlc.xgboost4j.demo.util.Params; import org.dmlc.xgboost4j.demo.util.Params;
import org.dmlc.xgboost4j.util.Trainer; import org.dmlc.xgboost4j.util.Trainer;
import org.dmlc.xgboost4j.util.XgboostError;
/** /**
* a simple example of java wrapper for xgboost * a simple example of java wrapper for xgboost
@ -52,7 +53,7 @@ public class BasicWalkThrough {
} }
public static void main(String[] args) throws UnsupportedEncodingException, IOException { public static void main(String[] args) throws UnsupportedEncodingException, IOException, XgboostError {
// load file from text file, also binary buffer generated by xgboost4j // load file from text file, also binary buffer generated by xgboost4j
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test"); DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");

View File

@ -23,13 +23,14 @@ import org.dmlc.xgboost4j.Booster;
import org.dmlc.xgboost4j.DMatrix; import org.dmlc.xgboost4j.DMatrix;
import org.dmlc.xgboost4j.demo.util.Params; import org.dmlc.xgboost4j.demo.util.Params;
import org.dmlc.xgboost4j.util.Trainer; import org.dmlc.xgboost4j.util.Trainer;
import org.dmlc.xgboost4j.util.XgboostError;
/** /**
* example for start from a initial base prediction * example for start from a initial base prediction
* @author hzx * @author hzx
*/ */
public class BoostFromPrediction { public class BoostFromPrediction {
public static void main(String[] args) { public static void main(String[] args) throws XgboostError {
System.out.println("start running example to start from a initial prediction"); System.out.println("start running example to start from a initial prediction");
// load file from text file, also binary buffer generated by xgboost4j // load file from text file, also binary buffer generated by xgboost4j

View File

@ -19,13 +19,14 @@ import java.io.IOException;
import org.dmlc.xgboost4j.DMatrix; import org.dmlc.xgboost4j.DMatrix;
import org.dmlc.xgboost4j.util.Trainer; import org.dmlc.xgboost4j.util.Trainer;
import org.dmlc.xgboost4j.demo.util.Params; import org.dmlc.xgboost4j.demo.util.Params;
import org.dmlc.xgboost4j.util.XgboostError;
/** /**
* an example of cross validation * an example of cross validation
* @author hzx * @author hzx
*/ */
public class CrossValidation { public class CrossValidation {
public static void main(String[] args) throws IOException { public static void main(String[] args) throws IOException, XgboostError {
//load train mat //load train mat
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");

View File

@ -19,12 +19,15 @@ import java.util.AbstractMap;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.dmlc.xgboost4j.Booster; import org.dmlc.xgboost4j.Booster;
import org.dmlc.xgboost4j.IEvaluation; import org.dmlc.xgboost4j.IEvaluation;
import org.dmlc.xgboost4j.DMatrix; import org.dmlc.xgboost4j.DMatrix;
import org.dmlc.xgboost4j.IObjective; import org.dmlc.xgboost4j.IObjective;
import org.dmlc.xgboost4j.demo.util.Params; import org.dmlc.xgboost4j.demo.util.Params;
import org.dmlc.xgboost4j.util.Trainer; import org.dmlc.xgboost4j.util.Trainer;
import org.dmlc.xgboost4j.util.XgboostError;
/** /**
* an example user define objective and eval * an example user define objective and eval
@ -40,6 +43,8 @@ public class CustomObjective {
* loglikelihoode loss obj function * loglikelihoode loss obj function
*/ */
public static class LogRegObj implements IObjective { public static class LogRegObj implements IObjective {
private static final Log logger = LogFactory.getLog(LogRegObj.class);
/** /**
* simple sigmoid func * simple sigmoid func
* @param input * @param input
@ -66,7 +71,13 @@ public class CustomObjective {
public List<float[]> getGradient(float[][] predicts, DMatrix dtrain) { public List<float[]> getGradient(float[][] predicts, DMatrix dtrain) {
int nrow = predicts.length; int nrow = predicts.length;
List<float[]> gradients = new ArrayList<>(); List<float[]> gradients = new ArrayList<>();
float[] labels = dtrain.getLabel(); float[] labels;
try {
labels = dtrain.getLabel();
} catch (XgboostError ex) {
logger.error(ex);
return null;
}
float[] grad = new float[nrow]; float[] grad = new float[nrow];
float[] hess = new float[nrow]; float[] hess = new float[nrow];
@ -93,6 +104,8 @@ public class CustomObjective {
* Take this in mind when you use the customization, and maybe you need write customized evaluation function * Take this in mind when you use the customization, and maybe you need write customized evaluation function
*/ */
public static class EvalError implements IEvaluation { public static class EvalError implements IEvaluation {
private static final Log logger = LogFactory.getLog(EvalError.class);
String evalMetric = "custom_error"; String evalMetric = "custom_error";
public EvalError() { public EvalError() {
@ -106,7 +119,13 @@ public class CustomObjective {
@Override @Override
public float eval(float[][] predicts, DMatrix dmat) { public float eval(float[][] predicts, DMatrix dmat) {
float error = 0f; float error = 0f;
float[] labels = dmat.getLabel(); float[] labels;
try {
labels = dmat.getLabel();
} catch (XgboostError ex) {
logger.error(ex);
return -1f;
}
int nrow = predicts.length; int nrow = predicts.length;
for(int i=0; i<nrow; i++) { for(int i=0; i<nrow; i++) {
if(labels[i]==0f && predicts[i][0]>0) { if(labels[i]==0f && predicts[i][0]>0) {
@ -121,7 +140,7 @@ public class CustomObjective {
} }
} }
public static void main(String[] args) { public static void main(String[] args) throws XgboostError {
//load train mat (svmlight format) //load train mat (svmlight format)
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
//load valid mat (svmlight format) //load valid mat (svmlight format)

View File

@ -23,13 +23,14 @@ import org.dmlc.xgboost4j.Booster;
import org.dmlc.xgboost4j.DMatrix; import org.dmlc.xgboost4j.DMatrix;
import org.dmlc.xgboost4j.demo.util.Params; import org.dmlc.xgboost4j.demo.util.Params;
import org.dmlc.xgboost4j.util.Trainer; import org.dmlc.xgboost4j.util.Trainer;
import org.dmlc.xgboost4j.util.XgboostError;
/** /**
* simple example for using external memory version * simple example for using external memory version
* @author hzx * @author hzx
*/ */
public class ExternalMemory { public class ExternalMemory {
public static void main(String[] args) { public static void main(String[] args) throws XgboostError {
//this is the only difference, add a # followed by a cache prefix name //this is the only difference, add a # followed by a cache prefix name
//several cache file with the prefix will be generated //several cache file with the prefix will be generated
//currently only support convert from libsvm file //currently only support convert from libsvm file

View File

@ -24,6 +24,7 @@ import org.dmlc.xgboost4j.DMatrix;
import org.dmlc.xgboost4j.demo.util.CustomEval; import org.dmlc.xgboost4j.demo.util.CustomEval;
import org.dmlc.xgboost4j.demo.util.Params; import org.dmlc.xgboost4j.demo.util.Params;
import org.dmlc.xgboost4j.util.Trainer; import org.dmlc.xgboost4j.util.Trainer;
import org.dmlc.xgboost4j.util.XgboostError;
/** /**
* this is an example of fit generalized linear model in xgboost * this is an example of fit generalized linear model in xgboost
@ -31,7 +32,7 @@ import org.dmlc.xgboost4j.util.Trainer;
* @author hzx * @author hzx
*/ */
public class GeneralizedLinearModel { public class GeneralizedLinearModel {
public static void main(String[] args) { public static void main(String[] args) throws XgboostError {
// load file from text file, also binary buffer generated by xgboost4j // load file from text file, also binary buffer generated by xgboost4j
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test"); DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");

View File

@ -25,13 +25,14 @@ import org.dmlc.xgboost4j.util.Trainer;
import org.dmlc.xgboost4j.demo.util.CustomEval; import org.dmlc.xgboost4j.demo.util.CustomEval;
import org.dmlc.xgboost4j.demo.util.Params; import org.dmlc.xgboost4j.demo.util.Params;
import org.dmlc.xgboost4j.util.XgboostError;
/** /**
* predict first ntree * predict first ntree
* @author hzx * @author hzx
*/ */
public class PredictFirstNtree { public class PredictFirstNtree {
public static void main(String[] args) { public static void main(String[] args) throws XgboostError {
// load file from text file, also binary buffer generated by xgboost4j // load file from text file, also binary buffer generated by xgboost4j
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test"); DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");

View File

@ -24,13 +24,14 @@ import org.dmlc.xgboost4j.Booster;
import org.dmlc.xgboost4j.DMatrix; import org.dmlc.xgboost4j.DMatrix;
import org.dmlc.xgboost4j.util.Trainer; import org.dmlc.xgboost4j.util.Trainer;
import org.dmlc.xgboost4j.demo.util.Params; import org.dmlc.xgboost4j.demo.util.Params;
import org.dmlc.xgboost4j.util.XgboostError;
/** /**
* predict leaf indices * predict leaf indices
* @author hzx * @author hzx
*/ */
public class PredictLeafIndices { public class PredictLeafIndices {
public static void main(String[] args) { public static void main(String[] args) throws XgboostError {
// load file from text file, also binary buffer generated by xgboost4j // load file from text file, also binary buffer generated by xgboost4j
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train"); DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test"); DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");

View File

@ -15,14 +15,18 @@
*/ */
package org.dmlc.xgboost4j.demo.util; package org.dmlc.xgboost4j.demo.util;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.dmlc.xgboost4j.DMatrix; import org.dmlc.xgboost4j.DMatrix;
import org.dmlc.xgboost4j.IEvaluation; import org.dmlc.xgboost4j.IEvaluation;
import org.dmlc.xgboost4j.util.XgboostError;
/** /**
* a util evaluation class for examples * a util evaluation class for examples
* @author hzx * @author hzx
*/ */
public class CustomEval implements IEvaluation { public class CustomEval implements IEvaluation {
private static final Log logger = LogFactory.getLog(CustomEval.class);
String evalMetric = "custom_error"; String evalMetric = "custom_error";
@ -34,7 +38,13 @@ public class CustomEval implements IEvaluation {
@Override @Override
public float eval(float[][] predicts, DMatrix dmat) { public float eval(float[][] predicts, DMatrix dmat) {
float error = 0f; float error = 0f;
float[] labels = dmat.getLabel(); float[] labels;
try {
labels = dmat.getLabel();
} catch (XgboostError ex) {
logger.error(ex);
return -1f;
}
int nrow = predicts.length; int nrow = predicts.length;
for(int i=0; i<nrow; i++) { for(int i=0; i<nrow; i++) {
if(labels[i]==0f && predicts[i][0]>0.5) { if(labels[i]==0f && predicts[i][0]>0.5) {

View File

@ -77,10 +77,8 @@ public class DataLoader {
reader.close(); reader.close();
in.close(); in.close();
Float[] flabels = (Float[]) tlabels.toArray(); denseData.labels = ArrayUtils.toPrimitive(tlabels.toArray(new Float[tlabels.size()]));
denseData.labels = ArrayUtils.toPrimitive(flabels); denseData.data = ArrayUtils.toPrimitive(tdata.toArray(new Float[tdata.size()]));
Float[] fdata = (Float[]) tdata.toArray();
denseData.data = ArrayUtils.toPrimitive(fdata);
return denseData; return denseData;
} }

View File

@ -30,6 +30,8 @@ import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
import org.dmlc.xgboost4j.util.Initializer; import org.dmlc.xgboost4j.util.Initializer;
import org.dmlc.xgboost4j.util.ErrorHandle;
import org.dmlc.xgboost4j.util.XgboostError;
import org.dmlc.xgboost4j.wrapper.XgboostJNI; import org.dmlc.xgboost4j.wrapper.XgboostJNI;
@ -57,8 +59,9 @@ public final class Booster {
* init Booster from dMatrixs * init Booster from dMatrixs
* @param params parameters * @param params parameters
* @param dMatrixs DMatrix array * @param dMatrixs DMatrix array
* @throws org.dmlc.xgboost4j.util.XgboostError
*/ */
public Booster(Iterable<Entry<String, Object>> params, DMatrix[] dMatrixs) { public Booster(Iterable<Entry<String, Object>> params, DMatrix[] dMatrixs) throws XgboostError {
init(dMatrixs); init(dMatrixs);
setParam("seed","0"); setParam("seed","0");
setParams(params); setParams(params);
@ -70,9 +73,11 @@ public final class Booster {
* load model from modelPath * load model from modelPath
* @param params parameters * @param params parameters
* @param modelPath booster modelPath (model generated by booster.saveModel) * @param modelPath booster modelPath (model generated by booster.saveModel)
* @throws org.dmlc.xgboost4j.util.XgboostError
*/ */
public Booster(Iterable<Entry<String, Object>> params, String modelPath) { public Booster(Iterable<Entry<String, Object>> params, String modelPath) throws XgboostError {
handle = XgboostJNI.XGBoosterCreate(new long[] {}); long[] out = new long[1];
init(null);
loadModel(modelPath); loadModel(modelPath);
setParam("seed","0"); setParam("seed","0");
setParams(params); setParams(params);
@ -81,28 +86,33 @@ public final class Booster {
private void init(DMatrix[] dMatrixs) { private void init(DMatrix[] dMatrixs) throws XgboostError {
long[] handles = null; long[] handles = null;
if(dMatrixs != null) { if(dMatrixs != null) {
handles = dMatrixs2handles(dMatrixs); handles = dMatrixs2handles(dMatrixs);
} }
handle = XgboostJNI.XGBoosterCreate(handles); long[] out = new long[1];
ErrorHandle.checkCall(XgboostJNI.XGBoosterCreate(handles, out));
handle = out[0];
} }
/** /**
* set parameter * set parameter
* @param key param name * @param key param name
* @param value param value * @param value param value
* @throws org.dmlc.xgboost4j.util.XgboostError
*/ */
public final void setParam(String key, String value) { public final void setParam(String key, String value) throws XgboostError {
XgboostJNI.XGBoosterSetParam(handle, key, value); ErrorHandle.checkCall(XgboostJNI.XGBoosterSetParam(handle, key, value));
} }
/** /**
* set parameters * set parameters
* @param params parameters key-value map * @param params parameters key-value map
* @throws org.dmlc.xgboost4j.util.XgboostError
*/ */
public void setParams(Iterable<Entry<String, Object>> params) { public void setParams(Iterable<Entry<String, Object>> params) throws XgboostError {
if(params!=null) { if(params!=null) {
for(Map.Entry<String, Object> entry : params) { for(Map.Entry<String, Object> entry : params) {
setParam(entry.getKey(), entry.getValue().toString()); setParam(entry.getKey(), entry.getValue().toString());
@ -115,9 +125,10 @@ public final class Booster {
* Update (one iteration) * Update (one iteration)
* @param dtrain training data * @param dtrain training data
* @param iter current iteration number * @param iter current iteration number
* @throws org.dmlc.xgboost4j.util.XgboostError
*/ */
public void update(DMatrix dtrain, int iter) { public void update(DMatrix dtrain, int iter) throws XgboostError {
XgboostJNI.XGBoosterUpdateOneIter(handle, iter, dtrain.getHandle()); ErrorHandle.checkCall(XgboostJNI.XGBoosterUpdateOneIter(handle, iter, dtrain.getHandle()));
} }
/** /**
@ -125,8 +136,9 @@ public final class Booster {
* @param dtrain training data * @param dtrain training data
* @param iter current iteration number * @param iter current iteration number
* @param obj customized objective class * @param obj customized objective class
* @throws org.dmlc.xgboost4j.util.XgboostError
*/ */
public void update(DMatrix dtrain, int iter, IObjective obj) { public void update(DMatrix dtrain, int iter, IObjective obj) throws XgboostError {
float[][] predicts = predict(dtrain, true); float[][] predicts = predict(dtrain, true);
List<float[]> gradients = obj.getGradient(predicts, dtrain); List<float[]> gradients = obj.getGradient(predicts, dtrain);
boost(dtrain, gradients.get(0), gradients.get(1)); boost(dtrain, gradients.get(0), gradients.get(1));
@ -137,12 +149,13 @@ public final class Booster {
* @param dtrain training data * @param dtrain training data
* @param grad first order of gradient * @param grad first order of gradient
* @param hess seconde order of gradient * @param hess seconde order of gradient
* @throws org.dmlc.xgboost4j.util.XgboostError
*/ */
public void boost(DMatrix dtrain, float[] grad, float[] hess) { public void boost(DMatrix dtrain, float[] grad, float[] hess) throws XgboostError {
if(grad.length != hess.length) { if(grad.length != hess.length) {
throw new AssertionError(String.format("grad/hess length mismatch %s / %s", grad.length, hess.length)); throw new AssertionError(String.format("grad/hess length mismatch %s / %s", grad.length, hess.length));
} }
XgboostJNI.XGBoosterBoostOneIter(handle, dtrain.getHandle(), grad, hess); ErrorHandle.checkCall(XgboostJNI.XGBoosterBoostOneIter(handle, dtrain.getHandle(), grad, hess));
} }
/** /**
@ -151,11 +164,13 @@ public final class Booster {
* @param evalNames name for eval dmatrixs, used for check results * @param evalNames name for eval dmatrixs, used for check results
* @param iter current eval iteration * @param iter current eval iteration
* @return eval information * @return eval information
* @throws org.dmlc.xgboost4j.util.XgboostError
*/ */
public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, int iter) { public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, int iter) throws XgboostError {
long[] handles = dMatrixs2handles(evalMatrixs); long[] handles = dMatrixs2handles(evalMatrixs);
String evalInfo = XgboostJNI.XGBoosterEvalOneIter(handle, iter, handles, evalNames); String[] evalInfo = new String[1];
return evalInfo; ErrorHandle.checkCall(XgboostJNI.XGBoosterEvalOneIter(handle, iter, handles, evalNames, evalInfo));
return evalInfo[0];
} }
/** /**
@ -165,8 +180,9 @@ public final class Booster {
* @param iter * @param iter
* @param eval * @param eval
* @return eval information * @return eval information
* @throws org.dmlc.xgboost4j.util.XgboostError
*/ */
public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, int iter, IEvaluation eval) { public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, int iter, IEvaluation eval) throws XgboostError {
String evalInfo = ""; String evalInfo = "";
for(int i=0; i<evalNames.length; i++) { for(int i=0; i<evalNames.length; i++) {
String evalName = evalNames[i]; String evalName = evalNames[i];
@ -184,10 +200,12 @@ public final class Booster {
* @param evalNames name for eval dmatrixs, used for check results * @param evalNames name for eval dmatrixs, used for check results
* @param iter current eval iteration * @param iter current eval iteration
* @return eval information * @return eval information
* @throws org.dmlc.xgboost4j.util.XgboostError
*/ */
public String evalSet(long[] dHandles, String[] evalNames, int iter) { public String evalSet(long[] dHandles, String[] evalNames, int iter) throws XgboostError {
String evalInfo = XgboostJNI.XGBoosterEvalOneIter(handle, iter, dHandles, evalNames); String[] evalInfo = new String[1];
return evalInfo; ErrorHandle.checkCall(XgboostJNI.XGBoosterEvalOneIter(handle, iter, dHandles, evalNames, evalInfo));
return evalInfo[0];
} }
@ -197,8 +215,9 @@ public final class Booster {
* @param evalName * @param evalName
* @param iter * @param iter
* @return eval information * @return eval information
* @throws org.dmlc.xgboost4j.util.XgboostError
*/ */
public String eval(DMatrix evalMat, String evalName, int iter) { public String eval(DMatrix evalMat, String evalName, int iter) throws XgboostError {
DMatrix[] evalMats = new DMatrix[] {evalMat}; DMatrix[] evalMats = new DMatrix[] {evalMat};
String[] evalNames = new String[] {evalName}; String[] evalNames = new String[] {evalName};
return evalSet(evalMats, evalNames, iter); return evalSet(evalMats, evalNames, iter);
@ -212,7 +231,7 @@ public final class Booster {
* @param predLeaf * @param predLeaf
* @return predict results * @return predict results
*/ */
private synchronized float[][] pred(DMatrix data, boolean outPutMargin, long treeLimit, boolean predLeaf) { private synchronized float[][] pred(DMatrix data, boolean outPutMargin, long treeLimit, boolean predLeaf) throws XgboostError {
int optionMask = 0; int optionMask = 0;
if(outPutMargin) { if(outPutMargin) {
optionMask = 1; optionMask = 1;
@ -220,15 +239,16 @@ public final class Booster {
if(predLeaf) { if(predLeaf) {
optionMask = 2; optionMask = 2;
} }
float[] rawPredicts = XgboostJNI.XGBoosterPredict(handle, data.getHandle(), optionMask, treeLimit); float[][] rawPredicts = new float[1][];
ErrorHandle.checkCall(XgboostJNI.XGBoosterPredict(handle, data.getHandle(), optionMask, treeLimit, rawPredicts));
int row = (int) data.rowNum(); int row = (int) data.rowNum();
int col = (int) rawPredicts.length/row; int col = (int) rawPredicts[0].length/row;
float[][] predicts = new float[row][col]; float[][] predicts = new float[row][col];
int r,c; int r,c;
for(int i=0; i< rawPredicts.length; i++) { for(int i=0; i< rawPredicts[0].length; i++) {
r = i/col; r = i/col;
c = i%col; c = i%col;
predicts[r][c] = rawPredicts[i]; predicts[r][c] = rawPredicts[0][i];
} }
return predicts; return predicts;
} }
@ -237,8 +257,9 @@ public final class Booster {
* Predict with data * Predict with data
* @param data dmatrix storing the input * @param data dmatrix storing the input
* @return predict result * @return predict result
* @throws org.dmlc.xgboost4j.util.XgboostError
*/ */
public float[][] predict(DMatrix data) { public float[][] predict(DMatrix data) throws XgboostError {
return pred(data, false, 0, false); return pred(data, false, 0, false);
} }
@ -247,8 +268,9 @@ public final class Booster {
* @param data dmatrix storing the input * @param data dmatrix storing the input
* @param outPutMargin Whether to output the raw untransformed margin value. * @param outPutMargin Whether to output the raw untransformed margin value.
* @return predict result * @return predict result
* @throws org.dmlc.xgboost4j.util.XgboostError
*/ */
public float[][] predict(DMatrix data, boolean outPutMargin) { public float[][] predict(DMatrix data, boolean outPutMargin) throws XgboostError {
return pred(data, outPutMargin, 0, false); return pred(data, outPutMargin, 0, false);
} }
@ -258,8 +280,9 @@ public final class Booster {
* @param outPutMargin Whether to output the raw untransformed margin value. * @param outPutMargin Whether to output the raw untransformed margin value.
* @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees). * @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees).
* @return predict result * @return predict result
* @throws org.dmlc.xgboost4j.util.XgboostError
*/ */
public float[][] predict(DMatrix data, boolean outPutMargin, long treeLimit) { public float[][] predict(DMatrix data, boolean outPutMargin, long treeLimit) throws XgboostError {
return pred(data, outPutMargin, treeLimit, false); return pred(data, outPutMargin, treeLimit, false);
} }
@ -272,8 +295,9 @@ public final class Booster {
Note that the leaf index of a tree is unique per tree, so you may find leaf 1 Note that the leaf index of a tree is unique per tree, so you may find leaf 1
in both tree 1 and tree 0. in both tree 1 and tree 0.
* @return predict result * @return predict result
* @throws org.dmlc.xgboost4j.util.XgboostError
*/ */
public float[][] predict(DMatrix data , long treeLimit, boolean predLeaf) { public float[][] predict(DMatrix data , long treeLimit, boolean predLeaf) throws XgboostError {
return pred(data, false, treeLimit, predLeaf); return pred(data, false, treeLimit, predLeaf);
} }
@ -293,14 +317,16 @@ public final class Booster {
* get the dump of the model as a string array * get the dump of the model as a string array
* @param withStats Controls whether the split statistics are output. * @param withStats Controls whether the split statistics are output.
* @return dumped model information * @return dumped model information
* @throws org.dmlc.xgboost4j.util.XgboostError
*/ */
public String[] getDumpInfo(boolean withStats) { public String[] getDumpInfo(boolean withStats) throws XgboostError {
int statsFlag = 0; int statsFlag = 0;
if(withStats) { if(withStats) {
statsFlag = 1; statsFlag = 1;
} }
String[] modelInfos = XgboostJNI.XGBoosterDumpModel(handle, "", statsFlag); String[][] modelInfos = new String[1][];
return modelInfos; ErrorHandle.checkCall(XgboostJNI.XGBoosterDumpModel(handle, "", statsFlag, modelInfos));
return modelInfos[0];
} }
/** /**
@ -308,14 +334,16 @@ public final class Booster {
* @param featureMap featureMap file * @param featureMap featureMap file
* @param withStats Controls whether the split statistics are output. * @param withStats Controls whether the split statistics are output.
* @return dumped model information * @return dumped model information
* @throws org.dmlc.xgboost4j.util.XgboostError
*/ */
public String[] getDumpInfo(String featureMap, boolean withStats) { public String[] getDumpInfo(String featureMap, boolean withStats) throws XgboostError {
int statsFlag = 0; int statsFlag = 0;
if(withStats) { if(withStats) {
statsFlag = 1; statsFlag = 1;
} }
String[] modelInfos = XgboostJNI.XGBoosterDumpModel(handle, featureMap, statsFlag); String[][] modelInfos = new String[1][];
return modelInfos; ErrorHandle.checkCall(XgboostJNI.XGBoosterDumpModel(handle, featureMap, statsFlag, modelInfos));
return modelInfos[0];
} }
/** /**
@ -326,8 +354,9 @@ public final class Booster {
* @throws FileNotFoundException * @throws FileNotFoundException
* @throws UnsupportedEncodingException * @throws UnsupportedEncodingException
* @throws IOException * @throws IOException
* @throws org.dmlc.xgboost4j.util.XgboostError
*/ */
public void dumpModel(String modelPath, boolean withStats) throws FileNotFoundException, UnsupportedEncodingException, IOException { public void dumpModel(String modelPath, boolean withStats) throws FileNotFoundException, UnsupportedEncodingException, IOException, XgboostError {
File tf = new File(modelPath); File tf = new File(modelPath);
FileOutputStream out = new FileOutputStream(tf); FileOutputStream out = new FileOutputStream(tf);
BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(out, "UTF-8")); BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(out, "UTF-8"));
@ -352,8 +381,9 @@ public final class Booster {
* @throws FileNotFoundException * @throws FileNotFoundException
* @throws UnsupportedEncodingException * @throws UnsupportedEncodingException
* @throws IOException * @throws IOException
* @throws org.dmlc.xgboost4j.util.XgboostError
*/ */
public void dumpModel(String modelPath, String featureMap, boolean withStats) throws FileNotFoundException, UnsupportedEncodingException, IOException { public void dumpModel(String modelPath, String featureMap, boolean withStats) throws FileNotFoundException, UnsupportedEncodingException, IOException, XgboostError {
File tf = new File(modelPath); File tf = new File(modelPath);
FileOutputStream out = new FileOutputStream(tf); FileOutputStream out = new FileOutputStream(tf);
BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(out, "UTF-8")); BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(out, "UTF-8"));
@ -372,8 +402,9 @@ public final class Booster {
/** /**
* get importance of each feature * get importance of each feature
* @return featureMap key: feature index, value: feature importance score * @return featureMap key: feature index, value: feature importance score
* @throws org.dmlc.xgboost4j.util.XgboostError
*/ */
public Map<String, Integer> getFeatureScore() { public Map<String, Integer> getFeatureScore() throws XgboostError {
String[] modelInfos = getDumpInfo(false); String[] modelInfos = getDumpInfo(false);
Map<String, Integer> featureScore = new HashMap<>(); Map<String, Integer> featureScore = new HashMap<>();
for(String tree : modelInfos) { for(String tree : modelInfos) {
@ -400,8 +431,9 @@ public final class Booster {
* get importance of each feature * get importance of each feature
* @param featureMap file to save dumped model info * @param featureMap file to save dumped model info
* @return featureMap key: feature index, value: feature importance score * @return featureMap key: feature index, value: feature importance score
* @throws org.dmlc.xgboost4j.util.XgboostError
*/ */
public Map<String, Integer> getFeatureScore(String featureMap) { public Map<String, Integer> getFeatureScore(String featureMap) throws XgboostError {
String[] modelInfos = getDumpInfo(featureMap, false); String[] modelInfos = getDumpInfo(featureMap, false);
Map<String, Integer> featureScore = new HashMap<>(); Map<String, Integer> featureScore = new HashMap<>();
for(String tree : modelInfos) { for(String tree : modelInfos) {

View File

@ -18,6 +18,8 @@ package org.dmlc.xgboost4j;
import java.io.IOException; import java.io.IOException;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
import org.dmlc.xgboost4j.util.ErrorHandle;
import org.dmlc.xgboost4j.util.XgboostError;
import org.dmlc.xgboost4j.util.Initializer; import org.dmlc.xgboost4j.util.Initializer;
import org.dmlc.xgboost4j.wrapper.XgboostJNI; import org.dmlc.xgboost4j.wrapper.XgboostJNI;
@ -50,9 +52,12 @@ public class DMatrix {
/** /**
* init DMatrix from file (svmlight format) * init DMatrix from file (svmlight format)
* @param dataPath * @param dataPath
* @throws org.dmlc.xgboost4j.util.XgboostError
*/ */
public DMatrix(String dataPath) { public DMatrix(String dataPath) throws XgboostError {
handle = XgboostJNI.XGDMatrixCreateFromFile(dataPath, 1); long[] out = new long[1];
ErrorHandle.checkCall(XgboostJNI.XGDMatrixCreateFromFile(dataPath, 1, out));
handle = out[0];
} }
/** /**
@ -61,17 +66,20 @@ public class DMatrix {
* @param indices Indices (colIndexs for CSR or rowIndexs for CSC) * @param indices Indices (colIndexs for CSR or rowIndexs for CSC)
* @param data non zero values (sequence by row for CSR or by col for CSC) * @param data non zero values (sequence by row for CSR or by col for CSC)
* @param st sparse matrix type (CSR or CSC) * @param st sparse matrix type (CSR or CSC)
* @throws org.dmlc.xgboost4j.util.XgboostError
*/ */
public DMatrix(long[] headers, int[] indices, float[] data, SparseType st) { public DMatrix(long[] headers, int[] indices, float[] data, SparseType st) throws XgboostError {
long[] out = new long[1];
if(st == SparseType.CSR) { if(st == SparseType.CSR) {
handle = XgboostJNI.XGDMatrixCreateFromCSR(headers, indices, data); ErrorHandle.checkCall(XgboostJNI.XGDMatrixCreateFromCSR(headers, indices, data, out));
} }
else if(st == SparseType.CSC) { else if(st == SparseType.CSC) {
handle = XgboostJNI.XGDMatrixCreateFromCSC(headers, indices, data); ErrorHandle.checkCall(XgboostJNI.XGDMatrixCreateFromCSC(headers, indices, data, out));
} }
else { else {
throw new UnknownError("unknow sparsetype"); throw new UnknownError("unknow sparsetype");
} }
handle = out[0];
} }
/** /**
@ -79,9 +87,12 @@ public class DMatrix {
* @param data data values * @param data data values
* @param nrow number of rows * @param nrow number of rows
* @param ncol number of columns * @param ncol number of columns
* @throws org.dmlc.xgboost4j.util.XgboostError
*/ */
public DMatrix(float[] data, int nrow, int ncol) { public DMatrix(float[] data, int nrow, int ncol) throws XgboostError {
handle = XgboostJNI.XGDMatrixCreateFromMat(data, nrow, ncol, 0.0f); long[] out = new long[1];
ErrorHandle.checkCall(XgboostJNI.XGDMatrixCreateFromMat(data, nrow, ncol, 0.0f, out));
handle = out[0];
} }
/** /**
@ -98,33 +109,36 @@ public class DMatrix {
* set label of dmatrix * set label of dmatrix
* @param labels * @param labels
*/ */
public void setLabel(float[] labels) { public void setLabel(float[] labels) throws XgboostError {
XgboostJNI.XGDMatrixSetFloatInfo(handle, "label", labels); ErrorHandle.checkCall(XgboostJNI.XGDMatrixSetFloatInfo(handle, "label", labels));
} }
/** /**
* set weight of each instance * set weight of each instance
* @param weights * @param weights
* @throws org.dmlc.xgboost4j.util.XgboostError
*/ */
public void setWeight(float[] weights) { public void setWeight(float[] weights) throws XgboostError {
XgboostJNI.XGDMatrixSetFloatInfo(handle, "weight", weights); ErrorHandle.checkCall(XgboostJNI.XGDMatrixSetFloatInfo(handle, "weight", weights));
} }
/** /**
* if specified, xgboost will start from this init margin * if specified, xgboost will start from this init margin
* can be used to specify initial prediction to boost from * can be used to specify initial prediction to boost from
* @param baseMargin * @param baseMargin
* @throws org.dmlc.xgboost4j.util.XgboostError
*/ */
public void setBaseMargin(float[] baseMargin) { public void setBaseMargin(float[] baseMargin) throws XgboostError {
XgboostJNI.XGDMatrixSetFloatInfo(handle, "base_margin", baseMargin); ErrorHandle.checkCall(XgboostJNI.XGDMatrixSetFloatInfo(handle, "base_margin", baseMargin));
} }
/** /**
* if specified, xgboost will start from this init margin * if specified, xgboost will start from this init margin
* can be used to specify initial prediction to boost from * can be used to specify initial prediction to boost from
* @param baseMargin * @param baseMargin
* @throws org.dmlc.xgboost4j.util.XgboostError
*/ */
public void setBaseMargin(float[][] baseMargin) { public void setBaseMargin(float[][] baseMargin) throws XgboostError {
float[] flattenMargin = flatten(baseMargin); float[] flattenMargin = flatten(baseMargin);
setBaseMargin(flattenMargin); setBaseMargin(flattenMargin);
} }
@ -132,42 +146,48 @@ public class DMatrix {
/** /**
* Set group sizes of DMatrix (used for ranking) * Set group sizes of DMatrix (used for ranking)
* @param group * @param group
* @throws org.dmlc.xgboost4j.util.XgboostError
*/ */
public void setGroup(int[] group) { public void setGroup(int[] group) throws XgboostError {
XgboostJNI.XGDMatrixSetGroup(handle, group); ErrorHandle.checkCall(XgboostJNI.XGDMatrixSetGroup(handle, group));
} }
private float[] getFloatInfo(String field) { private float[] getFloatInfo(String field) throws XgboostError {
float[] infos = XgboostJNI.XGDMatrixGetFloatInfo(handle, field); float[][] infos = new float[1][];
return infos; ErrorHandle.checkCall(XgboostJNI.XGDMatrixGetFloatInfo(handle, field, infos));
return infos[0];
} }
private int[] getIntInfo(String field) { private int[] getIntInfo(String field) throws XgboostError {
int[] infos = XgboostJNI.XGDMatrixGetUIntInfo(handle, field); int[][] infos = new int[1][];
return infos; ErrorHandle.checkCall(XgboostJNI.XGDMatrixGetUIntInfo(handle, field, infos));
return infos[0];
} }
/** /**
* get label values * get label values
* @return label * @return label
* @throws org.dmlc.xgboost4j.util.XgboostError
*/ */
public float[] getLabel() { public float[] getLabel() throws XgboostError {
return getFloatInfo("label"); return getFloatInfo("label");
} }
/** /**
* get weight of the DMatrix * get weight of the DMatrix
* @return weights * @return weights
* @throws org.dmlc.xgboost4j.util.XgboostError
*/ */
public float[] getWeight() { public float[] getWeight() throws XgboostError {
return getFloatInfo("weight"); return getFloatInfo("weight");
} }
/** /**
* get base margin of the DMatrix * get base margin of the DMatrix
* @return base margin * @return base margin
* @throws org.dmlc.xgboost4j.util.XgboostError
*/ */
public float[] getBaseMargin() { public float[] getBaseMargin() throws XgboostError {
return getFloatInfo("base_margin"); return getFloatInfo("base_margin");
} }
@ -175,9 +195,12 @@ public class DMatrix {
* Slice the DMatrix and return a new DMatrix that only contains `rowIndex`. * Slice the DMatrix and return a new DMatrix that only contains `rowIndex`.
* @param rowIndex * @param rowIndex
* @return sliced new DMatrix * @return sliced new DMatrix
* @throws org.dmlc.xgboost4j.util.XgboostError
*/ */
public DMatrix slice(int[] rowIndex) { public DMatrix slice(int[] rowIndex) throws XgboostError {
long sHandle = XgboostJNI.XGDMatrixSliceDMatrix(handle, rowIndex); long[] out = new long[1];
ErrorHandle.checkCall(XgboostJNI.XGDMatrixSliceDMatrix(handle, rowIndex, out));
long sHandle = out[0];
DMatrix sMatrix = new DMatrix(sHandle); DMatrix sMatrix = new DMatrix(sHandle);
return sMatrix; return sMatrix;
} }
@ -185,9 +208,12 @@ public class DMatrix {
/** /**
* get the row number of DMatrix * get the row number of DMatrix
* @return number of rows * @return number of rows
* @throws org.dmlc.xgboost4j.util.XgboostError
*/ */
public long rowNum() { public long rowNum() throws XgboostError {
return XgboostJNI.XGDMatrixNumRow(handle); long[] rowNum = new long[1];
ErrorHandle.checkCall(XgboostJNI.XGDMatrixNumRow(handle,rowNum));
return rowNum[0];
} }
/** /**

View File

@ -37,8 +37,9 @@ public class CVPack {
* @param dtrain train data * @param dtrain train data
* @param dtest test data * @param dtest test data
* @param params parameters * @param params parameters
* @throws org.dmlc.xgboost4j.util.XgboostError
*/ */
public CVPack(DMatrix dtrain, DMatrix dtest, Iterable<Map.Entry<String, Object>> params) { public CVPack(DMatrix dtrain, DMatrix dtest, Iterable<Map.Entry<String, Object>> params) throws XgboostError {
dmats = new DMatrix[] {dtrain, dtest}; dmats = new DMatrix[] {dtrain, dtest};
booster = new Booster(params, dmats); booster = new Booster(params, dmats);
names = new String[] {"train", "test"}; names = new String[] {"train", "test"};
@ -49,8 +50,9 @@ public class CVPack {
/** /**
* update one iteration * update one iteration
* @param iter iteration num * @param iter iteration num
* @throws org.dmlc.xgboost4j.util.XgboostError
*/ */
public void update(int iter) { public void update(int iter) throws XgboostError {
booster.update(dtrain, iter); booster.update(dtrain, iter);
} }
@ -58,8 +60,9 @@ public class CVPack {
* update one iteration * update one iteration
* @param iter iteration num * @param iter iteration num
* @param obj customized objective * @param obj customized objective
* @throws org.dmlc.xgboost4j.util.XgboostError
*/ */
public void update(int iter, IObjective obj) { public void update(int iter, IObjective obj) throws XgboostError {
booster.update(dtrain, iter, obj); booster.update(dtrain, iter, obj);
} }
@ -67,8 +70,9 @@ public class CVPack {
* evaluation * evaluation
* @param iter iteration num * @param iter iteration num
* @return * @return
* @throws org.dmlc.xgboost4j.util.XgboostError
*/ */
public String eval(int iter) { public String eval(int iter) throws XgboostError {
return booster.evalSet(dmats, names, iter); return booster.evalSet(dmats, names, iter);
} }
@ -77,8 +81,9 @@ public class CVPack {
* @param iter iteration num * @param iter iteration num
* @param eval customized eval * @param eval customized eval
* @return * @return
* @throws org.dmlc.xgboost4j.util.XgboostError
*/ */
public String eval(int iter, IEvaluation eval) { public String eval(int iter, IEvaluation eval) throws XgboostError {
return booster.evalSet(dmats, names, iter, eval); return booster.evalSet(dmats, names, iter, eval);
} }
} }

View File

@ -0,0 +1,50 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j.util;
import java.io.IOException;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.dmlc.xgboost4j.wrapper.XgboostJNI;
/**
* error handle for Xgboost
* @author hzx
*/
public class ErrorHandle {
private static final Log logger = LogFactory.getLog(ErrorHandle.class);
//load native library
static {
try {
Initializer.InitXgboost();
} catch (IOException ex) {
logger.error("load native library failed.");
logger.error(ex);
}
}
/**
* check the return value of C API
* @param ret return valud of xgboostJNI C API call
* @throws org.dmlc.xgboost4j.util.XgboostError
*/
public static void checkCall(int ret) throws XgboostError {
if(ret != 0) {
throw new XgboostError(XgboostJNI.XGBGetLastError());
}
}
}

View File

@ -47,7 +47,7 @@ public class Trainer {
* @return trained booster * @return trained booster
*/ */
public static Booster train(Iterable<Entry<String, Object>> params, DMatrix dtrain, int round, public static Booster train(Iterable<Entry<String, Object>> params, DMatrix dtrain, int round,
Iterable<Entry<String, DMatrix>> watchs, IObjective obj, IEvaluation eval) { Iterable<Entry<String, DMatrix>> watchs, IObjective obj, IEvaluation eval) throws XgboostError {
//collect eval matrixs //collect eval matrixs
String[] evalNames; String[] evalNames;
@ -112,7 +112,7 @@ public class Trainer {
* @param eval customized evaluation (set to null if not used) * @param eval customized evaluation (set to null if not used)
* @return evaluation history * @return evaluation history
*/ */
public static String[] crossValiation(Iterable<Entry<String, Object>> params, DMatrix data, int round, int nfold, String[] metrics, IObjective obj, IEvaluation eval) { public static String[] crossValiation(Iterable<Entry<String, Object>> params, DMatrix data, int round, int nfold, String[] metrics, IObjective obj, IEvaluation eval) throws XgboostError {
CVPack[] cvPacks = makeNFold(data, nfold, params, metrics); CVPack[] cvPacks = makeNFold(data, nfold, params, metrics);
String[] evalHist = new String[round]; String[] evalHist = new String[round];
String[] results = new String[cvPacks.length]; String[] results = new String[cvPacks.length];
@ -149,7 +149,7 @@ public class Trainer {
* @param evalMetrics Evaluation metrics * @param evalMetrics Evaluation metrics
* @return CV package array * @return CV package array
*/ */
public static CVPack[] makeNFold(DMatrix data, int nfold, Iterable<Entry<String, Object>> params, String[] evalMetrics) { public static CVPack[] makeNFold(DMatrix data, int nfold, Iterable<Entry<String, Object>> params, String[] evalMetrics) throws XgboostError {
List<Integer> samples = genRandPermutationNums(0, (int) data.rowNum()); List<Integer> samples = genRandPermutationNums(0, (int) data.rowNum());
int step = samples.size()/nfold; int step = samples.size()/nfold;
int[] testSlice = new int[step]; int[] testSlice = new int[step];

View File

@ -0,0 +1,26 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j.util;
/**
* custom error class for xgboost
* @author hzx
*/
public class XgboostError extends Exception{
public XgboostError(String message) {
super(message);
}
}

View File

@ -17,32 +17,34 @@ package org.dmlc.xgboost4j.wrapper;
/** /**
* xgboost jni wrapper functions for xgboost_wrapper.h * xgboost jni wrapper functions for xgboost_wrapper.h
* change 2015-7-6: *use a long[] (length=1) as container of handle to get the output DMatrix or Booster
* @author hzx * @author hzx
*/ */
public class XgboostJNI { public class XgboostJNI {
public final static native long XGDMatrixCreateFromFile(String fname, int silent); public final static native String XGBGetLastError();
public final static native long XGDMatrixCreateFromCSR(long[] indptr, int[] indices, float[] data); public final static native int XGDMatrixCreateFromFile(String fname, int silent, long[] out);
public final static native long XGDMatrixCreateFromCSC(long[] colptr, int[] indices, float[] data); public final static native int XGDMatrixCreateFromCSR(long[] indptr, int[] indices, float[] data, long[] out);
public final static native long XGDMatrixCreateFromMat(float[] data, int nrow, int ncol, float missing); public final static native int XGDMatrixCreateFromCSC(long[] colptr, int[] indices, float[] data, long[] out);
public final static native long XGDMatrixSliceDMatrix(long handle, int[] idxset); public final static native int XGDMatrixCreateFromMat(float[] data, int nrow, int ncol, float missing, long[] out);
public final static native void XGDMatrixFree(long handle); public final static native int XGDMatrixSliceDMatrix(long handle, int[] idxset, long[] out);
public final static native void XGDMatrixSaveBinary(long handle, String fname, int silent); public final static native int XGDMatrixFree(long handle);
public final static native void XGDMatrixSetFloatInfo(long handle, String field, float[] array); public final static native int XGDMatrixSaveBinary(long handle, String fname, int silent);
public final static native void XGDMatrixSetUIntInfo(long handle, String field, int[] array); public final static native int XGDMatrixSetFloatInfo(long handle, String field, float[] array);
public final static native void XGDMatrixSetGroup(long handle, int[] group); public final static native int XGDMatrixSetUIntInfo(long handle, String field, int[] array);
public final static native float[] XGDMatrixGetFloatInfo(long handle, String field); public final static native int XGDMatrixSetGroup(long handle, int[] group);
public final static native int[] XGDMatrixGetUIntInfo(long handle, String filed); public final static native int XGDMatrixGetFloatInfo(long handle, String field, float[][] info);
public final static native long XGDMatrixNumRow(long handle); public final static native int XGDMatrixGetUIntInfo(long handle, String filed, int[][] info);
public final static native long XGBoosterCreate(long[] handles); public final static native int XGDMatrixNumRow(long handle, long[] row);
public final static native void XGBoosterFree(long handle); public final static native int XGBoosterCreate(long[] handles, long[] out);
public final static native void XGBoosterSetParam(long handle, String name, String value); public final static native int XGBoosterFree(long handle);
public final static native void XGBoosterUpdateOneIter(long handle, int iter, long dtrain); public final static native int XGBoosterSetParam(long handle, String name, String value);
public final static native void XGBoosterBoostOneIter(long handle, long dtrain, float[] grad, float[] hess); public final static native int XGBoosterUpdateOneIter(long handle, int iter, long dtrain);
public final static native String XGBoosterEvalOneIter(long handle, int iter, long[] dmats, String[] evnames); public final static native int XGBoosterBoostOneIter(long handle, long dtrain, float[] grad, float[] hess);
public final static native float[] XGBoosterPredict(long handle, long dmat, int option_mask, long ntree_limit); public final static native int XGBoosterEvalOneIter(long handle, int iter, long[] dmats, String[] evnames, String[] eval_info);
public final static native void XGBoosterLoadModel(long handle, String fname); public final static native int XGBoosterPredict(long handle, long dmat, int option_mask, long ntree_limit, float[][] predicts);
public final static native void XGBoosterSaveModel(long handle, String fname); public final static native int XGBoosterLoadModel(long handle, String fname);
public final static native void XGBoosterLoadModelFromBuffer(long handle, long buf, long len); public final static native int XGBoosterSaveModel(long handle, String fname);
public final static native String XGBoosterGetModelRaw(long handle); public final static native int XGBoosterLoadModelFromBuffer(long handle, long buf, long len);
public final static native String[] XGBoosterDumpModel(long handle, String fmap, int with_stats); public final static native int XGBoosterGetModelRaw(long handle, String[] out_string);
public final static native int XGBoosterDumpModel(long handle, String fmap, int with_stats, String[][] out_strings);
} }

View File

@ -16,21 +16,34 @@
#include "../wrapper/xgboost_wrapper.h" #include "../wrapper/xgboost_wrapper.h"
#include "xgboost4j_wrapper.h" #include "xgboost4j_wrapper.h"
JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromFile JNIEXPORT jstring JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBGetLastError
(JNIEnv *jenv, jclass jcls, jstring jfname, jint jsilent) { (JNIEnv *jenv, jclass jcls) {
jlong jresult = 0 ; jstring jresult = 0 ;
char* result = 0;
result = (char *)XGBGetLastError();
if (result) jresult = jenv->NewStringUTF((const char *)result);
return jresult;
}
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromFile
(JNIEnv *jenv, jclass jcls, jstring jfname, jint jsilent, jlongArray jout) {
jint jresult = 0 ;
char *fname = (char *) 0 ; char *fname = (char *) 0 ;
int silent; int silent;
void *result = 0 ; void* result[1];
fname = 0; unsigned long out[1];
if (jfname) {
fname = (char *)jenv->GetStringUTFChars(jfname, 0); fname = (char *)jenv->GetStringUTFChars(jfname, 0);
if (!fname) return 0;
}
silent = (int)jsilent; silent = (int)jsilent;
result = (void *)XGDMatrixCreateFromFile((char const *)fname, silent); jresult = (jint) XGDMatrixCreateFromFile((char const *)fname, silent, result);
*(void **)&jresult = result;
*(void **)&out[0] = *result;
if (fname) jenv->ReleaseStringUTFChars(jfname, (const char *)fname); if (fname) jenv->ReleaseStringUTFChars(jfname, (const char *)fname);
jenv->SetLongArrayRegion(jout, 0, 1, (const jlong *) out);
return jresult; return jresult;
} }
@ -39,12 +52,13 @@ JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCrea
* Method: XGDMatrixCreateFromCSR * Method: XGDMatrixCreateFromCSR
* Signature: ([J[J[F)J * Signature: ([J[J[F)J
*/ */
JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromCSR JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromCSR
(JNIEnv *jenv, jclass jcls, jlongArray jindptr, jintArray jindices, jfloatArray jdata) { (JNIEnv *jenv, jclass jcls, jlongArray jindptr, jintArray jindices, jfloatArray jdata, jlongArray jout) {
jlong jresult = 0 ; jint jresult = 0 ;
bst_ulong nindptr ; bst_ulong nindptr ;
bst_ulong nelem; bst_ulong nelem;
void *result = 0 ; void *result[1];
unsigned long out[1];
jlong* indptr = jenv->GetLongArrayElements(jindptr, 0); jlong* indptr = jenv->GetLongArrayElements(jindptr, 0);
jint* indices = jenv->GetIntArrayElements(jindices, 0); jint* indices = jenv->GetIntArrayElements(jindices, 0);
@ -52,8 +66,9 @@ JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCrea
nindptr = (bst_ulong)jenv->GetArrayLength(jindptr); nindptr = (bst_ulong)jenv->GetArrayLength(jindptr);
nelem = (bst_ulong)jenv->GetArrayLength(jdata); nelem = (bst_ulong)jenv->GetArrayLength(jdata);
result = (void *)XGDMatrixCreateFromCSR((unsigned long const *)indptr, (unsigned int const *)indices, (float const *)data, nindptr, nelem); jresult = (jint) XGDMatrixCreateFromCSR((unsigned long const *)indptr, (unsigned int const *)indices, (float const *)data, nindptr, nelem, result);
*(void **)&jresult = result; *(void **)&out[0] = *result;
jenv->SetLongArrayRegion(jout, 0, 1, (const jlong *) out);
//release //release
jenv->ReleaseLongArrayElements(jindptr, indptr, 0); jenv->ReleaseLongArrayElements(jindptr, indptr, 0);
@ -68,12 +83,13 @@ JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCrea
* Method: XGDMatrixCreateFromCSC * Method: XGDMatrixCreateFromCSC
* Signature: ([J[J[F)J * Signature: ([J[J[F)J
*/ */
JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromCSC JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromCSC
(JNIEnv *jenv, jclass jcls, jlongArray jindptr, jintArray jindices, jfloatArray jdata) { (JNIEnv *jenv, jclass jcls, jlongArray jindptr, jintArray jindices, jfloatArray jdata, jlongArray jout) {
jlong jresult = 0 ; jint jresult = 0;
bst_ulong nindptr ; bst_ulong nindptr ;
bst_ulong nelem; bst_ulong nelem;
void *result = 0 ; void *result[1];
unsigned long out[1];
jlong* indptr = jenv->GetLongArrayElements(jindptr, NULL); jlong* indptr = jenv->GetLongArrayElements(jindptr, NULL);
jint* indices = jenv->GetIntArrayElements(jindices, 0); jint* indices = jenv->GetIntArrayElements(jindices, 0);
@ -81,8 +97,9 @@ JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCrea
nindptr = (bst_ulong)jenv->GetArrayLength(jindptr); nindptr = (bst_ulong)jenv->GetArrayLength(jindptr);
nelem = (bst_ulong)jenv->GetArrayLength(jdata); nelem = (bst_ulong)jenv->GetArrayLength(jdata);
result = (void *)XGDMatrixCreateFromCSC((unsigned long const *)indptr, (unsigned int const *)indices, (float const *)data, nindptr, nelem); jresult = (jint) XGDMatrixCreateFromCSC((unsigned long const *)indptr, (unsigned int const *)indices, (float const *)data, nindptr, nelem, result);
*(void **)&jresult = result; *(void **)&out[0] = *result;
jenv->SetLongArrayRegion(jout, 0, 1, (const jlong *) out);
//release //release
jenv->ReleaseLongArrayElements(jindptr, indptr, 0); jenv->ReleaseLongArrayElements(jindptr, indptr, 0);
@ -97,21 +114,24 @@ JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCrea
* Method: XGDMatrixCreateFromMat * Method: XGDMatrixCreateFromMat
* Signature: ([FIIF)J * Signature: ([FIIF)J
*/ */
JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromMat JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromMat
(JNIEnv *jenv, jclass jcls, jfloatArray jdata, jint jnrow, jint jncol, jfloat jmiss) { (JNIEnv *jenv, jclass jcls, jfloatArray jdata, jint jnrow, jint jncol, jfloat jmiss, jlongArray jout) {
jlong jresult = 0 ; jint jresult = 0 ;
bst_ulong nrow ; bst_ulong nrow ;
bst_ulong ncol ; bst_ulong ncol ;
float miss ; float miss ;
void *result = 0 ; void *result[1];
unsigned long out[1];
jfloat* data = jenv->GetFloatArrayElements(jdata, 0); jfloat* data = jenv->GetFloatArrayElements(jdata, 0);
nrow = (bst_ulong)jnrow; nrow = (bst_ulong)jnrow;
ncol = (bst_ulong)jncol; ncol = (bst_ulong)jncol;
miss = (float)jmiss; miss = (float)jmiss;
result = (void *)XGDMatrixCreateFromMat((float const *)data, nrow, ncol, miss);
*(void **)&jresult = result; jresult = (jint) XGDMatrixCreateFromMat((float const *)data, nrow, ncol, miss, result);
*(void **)&out[0] = *result;
jenv->SetLongArrayRegion(jout, 0, 1, (const jlong *) out);
//release //release
jenv->ReleaseFloatArrayElements(jdata, data, 0); jenv->ReleaseFloatArrayElements(jdata, data, 0);
@ -124,19 +144,21 @@ JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCrea
* Method: XGDMatrixSliceDMatrix * Method: XGDMatrixSliceDMatrix
* Signature: (J[I)J * Signature: (J[I)J
*/ */
JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSliceDMatrix JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSliceDMatrix
(JNIEnv *jenv, jclass jcls, jlong jhandle, jintArray jindexset) { (JNIEnv *jenv, jclass jcls, jlong jhandle, jintArray jindexset, jlongArray jout) {
jlong jresult = 0 ; jint jresult = 0 ;
void *handle = (void *) 0 ; void *handle = (void *) 0 ;
bst_ulong len; bst_ulong len;
void *result = 0 ; void *result[1];
unsigned long out[1];
jint* indexset = jenv->GetIntArrayElements(jindexset, 0); jint* indexset = jenv->GetIntArrayElements(jindexset, 0);
handle = *(void **)&jhandle; handle = *(void **)&jhandle;
len = (bst_ulong)jenv->GetArrayLength(jindexset); len = (bst_ulong)jenv->GetArrayLength(jindexset);
result = (void *)XGDMatrixSliceDMatrix(handle, (int const *)indexset, len); jresult = (jint) XGDMatrixSliceDMatrix(handle, (int const *)indexset, len, result);
*(void **)&jresult = result; *(void **)&out[0] = *result;
jenv->SetLongArrayRegion(jout, 0, 1, (const jlong *) out);
//release //release
jenv->ReleaseIntArrayElements(jindexset, indexset, 0); jenv->ReleaseIntArrayElements(jindexset, indexset, 0);
@ -149,11 +171,13 @@ JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSlic
* Method: XGDMatrixFree * Method: XGDMatrixFree
* Signature: (J)V * Signature: (J)V
*/ */
JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixFree JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixFree
(JNIEnv *jenv, jclass jcls, jlong jhandle) { (JNIEnv *jenv, jclass jcls, jlong jhandle) {
jint jresult = 0;
void *handle = (void *) 0 ; void *handle = (void *) 0 ;
handle = *(void **)&jhandle; handle = *(void **)&jhandle;
XGDMatrixFree(handle); jresult = (jint) XGDMatrixFree(handle);
return jresult;
} }
/* /*
@ -161,20 +185,21 @@ JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixFree
* Method: XGDMatrixSaveBinary * Method: XGDMatrixSaveBinary
* Signature: (JLjava/lang/String;I)V * Signature: (JLjava/lang/String;I)V
*/ */
JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSaveBinary JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSaveBinary
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfname, jint jsilent) { (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfname, jint jsilent) {
jint jresult = 0;
void *handle = (void *) 0 ; void *handle = (void *) 0 ;
char *fname = (char *) 0 ; char *fname = (char *) 0 ;
int silent ; int silent ;
handle = *(void **)&jhandle; handle = *(void **)&jhandle;
fname = 0; fname = 0;
if (jfname) { fname = (char *)jenv->GetStringUTFChars(jfname, 0);
fname = (char *)jenv->GetStringUTFChars(jfname, 0);
if (!fname) return ;
}
silent = (int)jsilent; silent = (int)jsilent;
XGDMatrixSaveBinary(handle, (char const *)fname, silent); jresult = (jint) XGDMatrixSaveBinary(handle, (char const *)fname, silent);
if (fname) jenv->ReleaseStringUTFChars(jfname, (const char *)fname); if (fname) jenv->ReleaseStringUTFChars(jfname, (const char *)fname);
return jresult;
} }
/* /*
@ -182,27 +207,28 @@ JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSaveB
* Method: XGDMatrixSetFloatInfo * Method: XGDMatrixSetFloatInfo
* Signature: (JLjava/lang/String;[F)V * Signature: (JLjava/lang/String;[F)V
*/ */
JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetFloatInfo JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetFloatInfo
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jfloatArray jarray) { (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jfloatArray jarray) {
jint jresult = 0;
void *handle = (void *) 0 ; void *handle = (void *) 0 ;
char *field = (char *) 0 ; char *field = (char *) 0 ;
bst_ulong len; bst_ulong len;
handle = *(void **)&jhandle; handle = *(void **)&jhandle;
field = 0;
if (jfield) { field = (char *)jenv->GetStringUTFChars(jfield, 0);
field = (char *)jenv->GetStringUTFChars(jfield, 0);
if (!field) return ;
}
jfloat* array = jenv->GetFloatArrayElements(jarray, NULL); jfloat* array = jenv->GetFloatArrayElements(jarray, NULL);
len = (bst_ulong)jenv->GetArrayLength(jarray); len = (bst_ulong)jenv->GetArrayLength(jarray);
XGDMatrixSetFloatInfo(handle, (char const *)field, (float const *)array, len); jresult = (jint) XGDMatrixSetFloatInfo(handle, (char const *)field, (float const *)array, len);
//release //release
if (field) jenv->ReleaseStringUTFChars(jfield, (const char *)field); if (field) jenv->ReleaseStringUTFChars(jfield, (const char *)field);
jenv->ReleaseFloatArrayElements(jarray, array, 0); jenv->ReleaseFloatArrayElements(jarray, array, 0);
return jresult;
} }
/* /*
@ -210,25 +236,26 @@ JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetFl
* Method: XGDMatrixSetUIntInfo * Method: XGDMatrixSetUIntInfo
* Signature: (JLjava/lang/String;[I)V * Signature: (JLjava/lang/String;[I)V
*/ */
JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetUIntInfo JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetUIntInfo
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jintArray jarray) { (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jintArray jarray) {
jint jresult = 0;
void *handle = (void *) 0 ; void *handle = (void *) 0 ;
char *field = (char *) 0 ; char *field = (char *) 0 ;
bst_ulong len ; bst_ulong len ;
handle = *(void **)&jhandle; handle = *(void **)&jhandle;
field = 0; field = 0;
if (jfield) { field = (char *)jenv->GetStringUTFChars(jfield, 0);
field = (char *)jenv->GetStringUTFChars(jfield, 0);
if (!field) return ;
}
jint* array = jenv->GetIntArrayElements(jarray, NULL); jint* array = jenv->GetIntArrayElements(jarray, NULL);
len = (bst_ulong)jenv->GetArrayLength(jarray); len = (bst_ulong)jenv->GetArrayLength(jarray);
XGDMatrixSetUIntInfo(handle, (char const *)field, (unsigned int const *)array, len); jresult = (jint) XGDMatrixSetUIntInfo(handle, (char const *)field, (unsigned int const *)array, len);
//release //release
if (field) jenv->ReleaseStringUTFChars(jfield, (const char *)field); if (field) jenv->ReleaseStringUTFChars(jfield, (const char *)field);
jenv->ReleaseIntArrayElements(jarray, array, 0); jenv->ReleaseIntArrayElements(jarray, array, 0);
return jresult;
} }
/* /*
@ -236,8 +263,9 @@ JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetUI
* Method: XGDMatrixSetGroup * Method: XGDMatrixSetGroup
* Signature: (J[I)V * Signature: (J[I)V
*/ */
JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetGroup JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetGroup
(JNIEnv * jenv, jclass jcls, jlong jhandle, jintArray jarray) { (JNIEnv * jenv, jclass jcls, jlong jhandle, jintArray jarray) {
jint jresult = 0;
void *handle = (void *) 0 ; void *handle = (void *) 0 ;
bst_ulong len ; bst_ulong len ;
@ -245,11 +273,12 @@ JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetGr
jint* array = jenv->GetIntArrayElements(jarray, NULL); jint* array = jenv->GetIntArrayElements(jarray, NULL);
len = (bst_ulong)jenv->GetArrayLength(jarray); len = (bst_ulong)jenv->GetArrayLength(jarray);
XGDMatrixSetGroup(handle, (unsigned int const *)array, len); jresult = (jint) XGDMatrixSetGroup(handle, (unsigned int const *)array, len);
//release //release
jenv->ReleaseIntArrayElements(jarray, array, 0); jenv->ReleaseIntArrayElements(jarray, array, 0);
return jresult;
} }
/* /*
@ -257,28 +286,31 @@ JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetGr
* Method: XGDMatrixGetFloatInfo * Method: XGDMatrixGetFloatInfo
* Signature: (JLjava/lang/String;)[F * Signature: (JLjava/lang/String;)[F
*/ */
JNIEXPORT jfloatArray JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixGetFloatInfo JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixGetFloatInfo
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield) { (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jobjectArray jout) {
jint jresult = 0;
void *handle = (void *) 0 ; void *handle = (void *) 0 ;
char *field = (char *) 0 ; char *field = (char *) 0 ;
bst_ulong len[1]; bst_ulong len[1];
*len = 0; *len = 0;
float *result = 0 ; float *result[1];
handle = *(void **)&jhandle; handle = *(void **)&jhandle;
field = 0; field = 0;
if (jfield) { if (jfield) {
field = (char *)jenv->GetStringUTFChars(jfield, 0); field = (char *)jenv->GetStringUTFChars(jfield, 0);
if (!field) return 0; if (!field) return 0;
} }
result = (float *)XGDMatrixGetFloatInfo((void const *)handle, (char const *)field, len); jresult = (jint) XGDMatrixGetFloatInfo(handle, (char const *)field, len, (const float **) result);
if (field) jenv->ReleaseStringUTFChars(jfield, (const char *)field); if (field) jenv->ReleaseStringUTFChars(jfield, (const char *)field);
jsize jlen = (jsize)*len; jsize jlen = (jsize)*len;
jfloatArray jresult = jenv->NewFloatArray(jlen); jfloatArray jarray = jenv->NewFloatArray(jlen);
jenv->SetFloatArrayRegion(jresult, 0, jlen, (jfloat *)result); jenv->SetFloatArrayRegion(jarray, 0, jlen, (jfloat *) *result);
jenv->SetObjectArrayElement(jout, 0, (jobject) jarray);
return jresult; return jresult;
} }
@ -287,28 +319,26 @@ JNIEXPORT jfloatArray JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatr
* Method: XGDMatrixGetUIntInfo * Method: XGDMatrixGetUIntInfo
* Signature: (JLjava/lang/String;)[I * Signature: (JLjava/lang/String;)[I
*/ */
JNIEXPORT jintArray JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixGetUIntInfo JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixGetUIntInfo
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield) { (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jobjectArray jout) {
jint jresult = 0;
void *handle = (void *) 0 ; void *handle = (void *) 0 ;
char *field = (char *) 0 ; char *field = (char *) 0 ;
bst_ulong len[1]; bst_ulong len[1];
*len = 0; *len = 0;
unsigned int *result = 0 ; unsigned int *result[1];
handle = *(void **)&jhandle; handle = *(void **)&jhandle;
field = 0; field = (char *)jenv->GetStringUTFChars(jfield, 0);
if (jfield) {
field = (char *)jenv->GetStringUTFChars(jfield, 0);
if (!field) return 0;
}
result = (unsigned int *)XGDMatrixGetUIntInfo((void const *)handle, (char const *)field, len); jresult = (jint) XGDMatrixGetUIntInfo(handle, (char const *)field, len, (const unsigned int **) result);
if (field) jenv->ReleaseStringUTFChars(jfield, (const char *)field); if (field) jenv->ReleaseStringUTFChars(jfield, (const char *)field);
jsize jlen = (jsize)*len; jsize jlen = (jsize)*len;
jintArray jresult = jenv->NewIntArray(jlen); jintArray jarray = jenv->NewIntArray(jlen);
jenv->SetIntArrayRegion(jresult, 0, jlen, (jint *)result); jenv->SetIntArrayRegion(jarray, 0, jlen, (jint *) *result);
jenv->SetObjectArrayElement(jout, 0, jarray);
return jresult; return jresult;
} }
@ -317,14 +347,14 @@ JNIEXPORT jintArray JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrix
* Method: XGDMatrixNumRow * Method: XGDMatrixNumRow
* Signature: (J)J * Signature: (J)J
*/ */
JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixNumRow JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixNumRow
(JNIEnv *jenv, jclass jcls, jlong jhandle) { (JNIEnv *jenv, jclass jcls, jlong jhandle, jlongArray jout) {
jlong jresult = 0 ; jint jresult = 0 ;
void *handle = (void *) 0 ; void *handle = (void *) 0 ;
bst_ulong result; bst_ulong result[1];
handle = *(void **)&jhandle; handle = *(void **)&jhandle;
result = (bst_ulong)XGDMatrixNumRow((void const *)handle); jresult = (jint) XGDMatrixNumRow(handle, result);
jresult = (jlong)result; jenv->SetLongArrayRegion(jout, 0, 1, (const jlong *) result);
return jresult; return jresult;
} }
@ -333,13 +363,14 @@ JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixNumR
* Method: XGBoosterCreate * Method: XGBoosterCreate
* Signature: ([J)J * Signature: ([J)J
*/ */
JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterCreate JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterCreate
(JNIEnv *jenv, jclass jcls, jlongArray jhandles) { (JNIEnv *jenv, jclass jcls, jlongArray jhandles, jlongArray jout) {
jlong jresult = 0 ; jint jresult = 0;
void **handles = 0; void **handles = 0;
bst_ulong len = 0; bst_ulong len = 0;
void *result = 0 ; void *result[1];
jlong* cjhandles = 0; jlong* cjhandles = 0;
unsigned long out[1];
if(jhandles) { if(jhandles) {
len = (bst_ulong)jenv->GetArrayLength(jhandles); len = (bst_ulong)jenv->GetArrayLength(jhandles);
@ -351,7 +382,7 @@ JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterCrea
} }
} }
result = (void *)XGBoosterCreate(handles, len); jresult = (jint) XGBoosterCreate(handles, len, result);
//release //release
if(jhandles) { if(jhandles) {
@ -359,7 +390,9 @@ JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterCrea
jenv->ReleaseLongArrayElements(jhandles, cjhandles, 0); jenv->ReleaseLongArrayElements(jhandles, cjhandles, 0);
} }
*(void **)&jresult = result; *(void **)&out[0] = *result;
jenv->SetLongArrayRegion(jout, 0, 1, (const jlong *) out);
return jresult; return jresult;
} }
@ -368,11 +401,11 @@ JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterCrea
* Method: XGBoosterFree * Method: XGBoosterFree
* Signature: (J)V * Signature: (J)V
*/ */
JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterFree JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterFree
(JNIEnv *jenv, jclass jcls, jlong jhandle) { (JNIEnv *jenv, jclass jcls, jlong jhandle) {
void *handle = (void *) 0 ; void *handle = (void *) 0 ;
handle = *(void **)&jhandle; handle = *(void **)&jhandle;
XGBoosterFree(handle); return (jint) XGBoosterFree(handle);
} }
@ -381,27 +414,22 @@ JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterFree
* Method: XGBoosterSetParam * Method: XGBoosterSetParam
* Signature: (JLjava/lang/String;Ljava/lang/String;)V * Signature: (JLjava/lang/String;Ljava/lang/String;)V
*/ */
JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterSetParam JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterSetParam
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jname, jstring jvalue) { (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jname, jstring jvalue) {
jint jresult = -1;
void *handle = (void *) 0 ; void *handle = (void *) 0 ;
char *name = (char *) 0 ; char *name = (char *) 0 ;
char *value = (char *) 0 ; char *value = (char *) 0 ;
handle = *(void **)&jhandle; handle = *(void **)&jhandle;
name = 0; name = (char *)jenv->GetStringUTFChars(jname, 0);
if (jname) { value = (char *)jenv->GetStringUTFChars(jvalue, 0);
name = (char *)jenv->GetStringUTFChars(jname, 0);
if (!name) return ;
}
value = 0; jresult = (jint) XGBoosterSetParam(handle, (char const *)name, (char const *)value);
if (jvalue) {
value = (char *)jenv->GetStringUTFChars(jvalue, 0);
if (!value) return ;
}
XGBoosterSetParam(handle, (char const *)name, (char const *)value);
if (name) jenv->ReleaseStringUTFChars(jname, (const char *)name); if (name) jenv->ReleaseStringUTFChars(jname, (const char *)name);
if (value) jenv->ReleaseStringUTFChars(jvalue, (const char *)value); if (value) jenv->ReleaseStringUTFChars(jvalue, (const char *)value);
return jresult;
} }
/* /*
@ -409,7 +437,7 @@ JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterSetPa
* Method: XGBoosterUpdateOneIter * Method: XGBoosterUpdateOneIter
* Signature: (JIJ)V * Signature: (JIJ)V
*/ */
JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterUpdateOneIter JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterUpdateOneIter
(JNIEnv *jenv, jclass jcls, jlong jhandle, jint jiter, jlong jdtrain) { (JNIEnv *jenv, jclass jcls, jlong jhandle, jint jiter, jlong jdtrain) {
void *handle = (void *) 0 ; void *handle = (void *) 0 ;
int iter ; int iter ;
@ -417,7 +445,7 @@ JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterUpdat
handle = *(void **)&jhandle; handle = *(void **)&jhandle;
iter = (int)jiter; iter = (int)jiter;
dtrain = *(void **)&jdtrain; dtrain = *(void **)&jdtrain;
XGBoosterUpdateOneIter(handle, iter, dtrain); return (jint) XGBoosterUpdateOneIter(handle, iter, dtrain);
} }
/* /*
@ -425,8 +453,9 @@ JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterUpdat
* Method: XGBoosterBoostOneIter * Method: XGBoosterBoostOneIter
* Signature: (JJ[F[F)V * Signature: (JJ[F[F)V
*/ */
JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterBoostOneIter JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterBoostOneIter
(JNIEnv *jenv, jclass jcls, jlong jhandle, jlong jdtrain, jfloatArray jgrad, jfloatArray jhess) { (JNIEnv *jenv, jclass jcls, jlong jhandle, jlong jdtrain, jfloatArray jgrad, jfloatArray jhess) {
jint jresult = 0;
void *handle = (void *) 0 ; void *handle = (void *) 0 ;
void *dtrain = (void *) 0 ; void *dtrain = (void *) 0 ;
bst_ulong len ; bst_ulong len ;
@ -436,11 +465,13 @@ JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterBoost
jfloat* grad = jenv->GetFloatArrayElements(jgrad, 0); jfloat* grad = jenv->GetFloatArrayElements(jgrad, 0);
jfloat* hess = jenv->GetFloatArrayElements(jhess, 0); jfloat* hess = jenv->GetFloatArrayElements(jhess, 0);
len = (bst_ulong)jenv->GetArrayLength(jgrad); len = (bst_ulong)jenv->GetArrayLength(jgrad);
XGBoosterBoostOneIter(handle, dtrain, grad, hess, len); jresult = (jint) XGBoosterBoostOneIter(handle, dtrain, grad, hess, len);
//release //release
jenv->ReleaseFloatArrayElements(jgrad, grad, 0); jenv->ReleaseFloatArrayElements(jgrad, grad, 0);
jenv->ReleaseFloatArrayElements(jhess, hess, 0); jenv->ReleaseFloatArrayElements(jhess, hess, 0);
return jresult;
} }
/* /*
@ -448,15 +479,15 @@ JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterBoost
* Method: XGBoosterEvalOneIter * Method: XGBoosterEvalOneIter
* Signature: (JI[J[Ljava/lang/String;)Ljava/lang/String; * Signature: (JI[J[Ljava/lang/String;)Ljava/lang/String;
*/ */
JNIEXPORT jstring JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterEvalOneIter JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterEvalOneIter
(JNIEnv *jenv, jclass jcls, jlong jhandle, jint jiter, jlongArray jdmats, jobjectArray jevnames) { (JNIEnv *jenv, jclass jcls, jlong jhandle, jint jiter, jlongArray jdmats, jobjectArray jevnames, jobjectArray jout) {
jstring jresult = 0 ; jint jresult = 0 ;
void *handle = (void *) 0 ; void *handle = (void *) 0 ;
int iter ; int iter ;
void **dmats = 0; void **dmats = 0;
char **evnames = 0; char **evnames = 0;
bst_ulong len ; bst_ulong len ;
char *result = 0 ; char *result[1];
handle = *(void **)&jhandle; handle = *(void **)&jhandle;
iter = (int)jiter; iter = (int)jiter;
@ -480,7 +511,7 @@ JNIEXPORT jstring JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterEv
evnames[i] = (char *)jenv->GetStringUTFChars(jevname, 0); evnames[i] = (char *)jenv->GetStringUTFChars(jevname, 0);
} }
result = (char *)XGBoosterEvalOneIter(handle, iter, dmats, (char const *(*))evnames, len); jresult = (jint) XGBoosterEvalOneIter(handle, iter, dmats, (char const *(*))evnames, len, (const char **) result);
if(len > 0) { if(len > 0) {
delete[] dmats; delete[] dmats;
@ -493,7 +524,9 @@ JNIEXPORT jstring JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterEv
jenv->ReleaseLongArrayElements(jdmats, cjdmats, 0); jenv->ReleaseLongArrayElements(jdmats, cjdmats, 0);
} }
if (result) jresult = jenv->NewStringUTF((const char *)result); jstring jinfo = 0;
if (*result) jinfo = jenv->NewStringUTF((const char *) *result);
jenv->SetObjectArrayElement(jout, 0, jinfo);
return jresult; return jresult;
} }
@ -503,26 +536,29 @@ JNIEXPORT jstring JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterEv
* Method: XGBoosterPredict * Method: XGBoosterPredict
* Signature: (JJIJ)[F * Signature: (JJIJ)[F
*/ */
JNIEXPORT jfloatArray JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterPredict JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterPredict
(JNIEnv *jenv, jclass jcls, jlong jhandle, jlong jdmat, jint joption_mask, jlong jntree_limit) { (JNIEnv *jenv, jclass jcls, jlong jhandle, jlong jdmat, jint joption_mask, jlong jntree_limit, jobjectArray jout) {
jint jresult = 0;
void *handle = (void *) 0 ; void *handle = (void *) 0 ;
void *dmat = (void *) 0 ; void *dmat = (void *) 0 ;
int option_mask ; int option_mask ;
unsigned int ntree_limit ; unsigned int ntree_limit ;
bst_ulong len[1]; bst_ulong len[1];
*len = 0; *len = 0;
float *result = 0 ; float *result[1];
handle = *(void **)&jhandle; handle = *(void **)&jhandle;
dmat = *(void **)&jdmat; dmat = *(void **)&jdmat;
option_mask = (int)joption_mask; option_mask = (int)joption_mask;
ntree_limit = (unsigned int)jntree_limit; ntree_limit = (unsigned int)jntree_limit;
result = (float *)XGBoosterPredict(handle, dmat, option_mask, ntree_limit, len); jresult = (jint) XGBoosterPredict(handle, dmat, option_mask, ntree_limit, len, (const float **) result);
jsize jlen = (jsize)*len; jsize jlen = (jsize)*len;
jfloatArray jresult = jenv->NewFloatArray(jlen); jfloatArray jarray = jenv->NewFloatArray(jlen);
jenv->SetFloatArrayRegion(jresult, 0, jlen, (jfloat *)result); jenv->SetFloatArrayRegion(jarray, 0, jlen, (jfloat *) *result);
jenv->SetObjectArrayElement(jout, 0, jarray);
return jresult; return jresult;
} }
@ -531,18 +567,20 @@ JNIEXPORT jfloatArray JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoost
* Method: XGBoosterLoadModel * Method: XGBoosterLoadModel
* Signature: (JLjava/lang/String;)V * Signature: (JLjava/lang/String;)V
*/ */
JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterLoadModel JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterLoadModel
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfname) { (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfname) {
jint jresult = 0;
void *handle = (void *) 0 ; void *handle = (void *) 0 ;
char *fname = (char *) 0 ; char *fname = (char *) 0 ;
handle = *(void **)&jhandle; handle = *(void **)&jhandle;
fname = 0;
if (jfname) { fname = (char *)jenv->GetStringUTFChars(jfname, 0);
fname = (char *)jenv->GetStringUTFChars(jfname, 0);
if (!fname) return ;
} jresult = (jint) XGBoosterLoadModel(handle,(char const *)fname);
XGBoosterLoadModel(handle,(char const *)fname);
if (fname) jenv->ReleaseStringUTFChars(jfname, (const char *)fname); if (fname) jenv->ReleaseStringUTFChars(jfname, (const char *)fname);
return jresult;
} }
/* /*
@ -550,18 +588,19 @@ JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterLoadM
* Method: XGBoosterSaveModel * Method: XGBoosterSaveModel
* Signature: (JLjava/lang/String;)V * Signature: (JLjava/lang/String;)V
*/ */
JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterSaveModel JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterSaveModel
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfname) { (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfname) {
jint jresult = 0;
void *handle = (void *) 0 ; void *handle = (void *) 0 ;
char *fname = (char *) 0 ; char *fname = (char *) 0 ;
handle = *(void **)&jhandle; handle = *(void **)&jhandle;
fname = 0; fname = 0;
if (jfname) { fname = (char *)jenv->GetStringUTFChars(jfname, 0);
fname = (char *)jenv->GetStringUTFChars(jfname, 0);
if (!fname) return ; jresult = (jint) XGBoosterSaveModel(handle, (char const *)fname);
}
XGBoosterSaveModel(handle, (char const *)fname);
if (fname) jenv->ReleaseStringUTFChars(jfname, (const char *)fname); if (fname) jenv->ReleaseStringUTFChars(jfname, (const char *)fname);
return jresult;
} }
/* /*
@ -569,7 +608,7 @@ JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterSaveM
* Method: XGBoosterLoadModelFromBuffer * Method: XGBoosterLoadModelFromBuffer
* Signature: (JJJ)V * Signature: (JJJ)V
*/ */
JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterLoadModelFromBuffer JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterLoadModelFromBuffer
(JNIEnv *jenv, jclass jcls, jlong jhandle, jlong jbuf, jlong jlen) { (JNIEnv *jenv, jclass jcls, jlong jhandle, jlong jbuf, jlong jlen) {
void *handle = (void *) 0 ; void *handle = (void *) 0 ;
void *buf = (void *) 0 ; void *buf = (void *) 0 ;
@ -577,7 +616,7 @@ JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterLoadM
handle = *(void **)&jhandle; handle = *(void **)&jhandle;
buf = *(void **)&jbuf; buf = *(void **)&jbuf;
len = (bst_ulong)jlen; len = (bst_ulong)jlen;
XGBoosterLoadModelFromBuffer(handle, (void const *)buf, len); return (jint) XGBoosterLoadModelFromBuffer(handle, (void const *)buf, len);
} }
/* /*
@ -585,17 +624,21 @@ JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterLoadM
* Method: XGBoosterGetModelRaw * Method: XGBoosterGetModelRaw
* Signature: (J)Ljava/lang/String; * Signature: (J)Ljava/lang/String;
*/ */
JNIEXPORT jstring JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterGetModelRaw JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterGetModelRaw
(JNIEnv * jenv, jclass jcls, jlong jhandle) { (JNIEnv * jenv, jclass jcls, jlong jhandle, jobjectArray jout) {
jstring jresult = 0 ; jint jresult = 0 ;
jstring jinfo = 0;
void *handle = (void *) 0 ; void *handle = (void *) 0 ;
bst_ulong len[1]; bst_ulong len[1];
*len = 0; *len = 0;
char *result = 0 ; char *result[1];
handle = *(void **)&jhandle; handle = *(void **)&jhandle;
result = (char *)XGBoosterGetModelRaw(handle, len); jresult = (jint)XGBoosterGetModelRaw(handle, len, (const char **) result);
if (result) jresult = jenv->NewStringUTF((const char *)result); if (*result){
jinfo = jenv->NewStringUTF((const char *) *result);
jenv->SetObjectArrayElement(jout, 0, jinfo);
}
return jresult; return jresult;
} }
@ -604,15 +647,16 @@ JNIEXPORT jstring JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterGe
* Method: XGBoosterDumpModel * Method: XGBoosterDumpModel
* Signature: (JLjava/lang/String;I)[Ljava/lang/String; * Signature: (JLjava/lang/String;I)[Ljava/lang/String;
*/ */
JNIEXPORT jobjectArray JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterDumpModel JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterDumpModel
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfmap, jint jwith_stats) { (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfmap, jint jwith_stats, jobjectArray jout) {
jint jresult = 0;
void *handle = (void *) 0 ; void *handle = (void *) 0 ;
char *fmap = (char *) 0 ; char *fmap = (char *) 0 ;
int with_stats ; int with_stats ;
bst_ulong len[1]; bst_ulong len[1];
*len = 0; *len = 0;
char **result = 0 ; char **result[1];
handle = *(void **)&jhandle; handle = *(void **)&jhandle;
fmap = 0; fmap = 0;
if (jfmap) { if (jfmap) {
@ -621,14 +665,16 @@ JNIEXPORT jobjectArray JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoos
} }
with_stats = (int)jwith_stats; with_stats = (int)jwith_stats;
result = (char **)XGBoosterDumpModel(handle, (char const *)fmap, with_stats, len); jresult = (jint) XGBoosterDumpModel(handle, (const char *)fmap, with_stats, len, (const char ***) result);
jsize jlen = (jsize)*len; jsize jlen = (jsize)*len;
jobjectArray jresult = jenv->NewObjectArray(jlen, jenv->FindClass("java/lang/String"), jenv->NewStringUTF("")); jobjectArray jinfos = jenv->NewObjectArray(jlen, jenv->FindClass("java/lang/String"), jenv->NewStringUTF(""));
for(int i=0 ; i<jlen; i++) { for(int i=0 ; i<jlen; i++) {
jenv->SetObjectArrayElement(jresult, i, jenv->NewStringUTF((const char*)result[i])); jenv->SetObjectArrayElement(jinfos, i, jenv->NewStringUTF((const char*) result[0][i]));
} }
jenv->SetObjectArrayElement(jout, 0, jinfos);
if (fmap) jenv->ReleaseStringUTFChars(jfmap, (const char *)fmap); if (fmap) jenv->ReleaseStringUTFChars(jfmap, (const char *)fmap);
return jresult; return jresult;
} }

View File

@ -9,203 +9,211 @@ extern "C" {
#endif #endif
/* /*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
* Method: XGDMatrixCreateFromFile * Method: XGBGetLastError
* Signature: (Ljava/lang/String;I)J * Signature: ()Ljava/lang/String;
*/ */
JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromFile JNIEXPORT jstring JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBGetLastError
(JNIEnv *, jclass, jstring, jint); (JNIEnv *, jclass);
/*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
* Method: XGDMatrixCreateFromFile
* Signature: (Ljava/lang/String;I[J)I
*/
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromFile
(JNIEnv *, jclass, jstring, jint, jlongArray);
/* /*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
* Method: XGDMatrixCreateFromCSR * Method: XGDMatrixCreateFromCSR
* Signature: ([J[J[F)J * Signature: ([J[I[F[J)I
*/ */
JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromCSR JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromCSR
(JNIEnv *, jclass, jlongArray, jintArray, jfloatArray); (JNIEnv *, jclass, jlongArray, jintArray, jfloatArray, jlongArray);
/* /*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
* Method: XGDMatrixCreateFromCSC * Method: XGDMatrixCreateFromCSC
* Signature: ([J[J[F)J * Signature: ([J[I[F[J)I
*/ */
JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromCSC JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromCSC
(JNIEnv *, jclass, jlongArray, jintArray, jfloatArray); (JNIEnv *, jclass, jlongArray, jintArray, jfloatArray, jlongArray);
/* /*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
* Method: XGDMatrixCreateFromMat * Method: XGDMatrixCreateFromMat
* Signature: ([FIIF)J * Signature: ([FIIF[J)I
*/ */
JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromMat JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromMat
(JNIEnv *, jclass, jfloatArray, jint, jint, jfloat); (JNIEnv *, jclass, jfloatArray, jint, jint, jfloat, jlongArray);
/* /*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
* Method: XGDMatrixSliceDMatrix * Method: XGDMatrixSliceDMatrix
* Signature: (J[I)J * Signature: (J[I[J)I
*/ */
JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSliceDMatrix JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSliceDMatrix
(JNIEnv *, jclass, jlong, jintArray); (JNIEnv *, jclass, jlong, jintArray, jlongArray);
/* /*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
* Method: XGDMatrixFree * Method: XGDMatrixFree
* Signature: (J)V * Signature: (J)I
*/ */
JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixFree JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixFree
(JNIEnv *, jclass, jlong); (JNIEnv *, jclass, jlong);
/* /*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
* Method: XGDMatrixSaveBinary * Method: XGDMatrixSaveBinary
* Signature: (JLjava/lang/String;I)V * Signature: (JLjava/lang/String;I)I
*/ */
JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSaveBinary JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSaveBinary
(JNIEnv *, jclass, jlong, jstring, jint); (JNIEnv *, jclass, jlong, jstring, jint);
/* /*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
* Method: XGDMatrixSetFloatInfo * Method: XGDMatrixSetFloatInfo
* Signature: (JLjava/lang/String;[F)V * Signature: (JLjava/lang/String;[F)I
*/ */
JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetFloatInfo JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetFloatInfo
(JNIEnv *, jclass, jlong, jstring, jfloatArray); (JNIEnv *, jclass, jlong, jstring, jfloatArray);
/* /*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
* Method: XGDMatrixSetUIntInfo * Method: XGDMatrixSetUIntInfo
* Signature: (JLjava/lang/String;[I)V * Signature: (JLjava/lang/String;[I)I
*/ */
JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetUIntInfo JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetUIntInfo
(JNIEnv *, jclass, jlong, jstring, jintArray); (JNIEnv *, jclass, jlong, jstring, jintArray);
/* /*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
* Method: XGDMatrixSetGroup * Method: XGDMatrixSetGroup
* Signature: (J[I)V * Signature: (J[I)I
*/ */
JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetGroup JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetGroup
(JNIEnv *, jclass, jlong, jintArray); (JNIEnv *, jclass, jlong, jintArray);
/* /*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
* Method: XGDMatrixGetFloatInfo * Method: XGDMatrixGetFloatInfo
* Signature: (JLjava/lang/String;)[F * Signature: (JLjava/lang/String;[[F)I
*/ */
JNIEXPORT jfloatArray JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixGetFloatInfo JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixGetFloatInfo
(JNIEnv *, jclass, jlong, jstring); (JNIEnv *, jclass, jlong, jstring, jobjectArray);
/* /*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
* Method: XGDMatrixGetUIntInfo * Method: XGDMatrixGetUIntInfo
* Signature: (JLjava/lang/String;)[I * Signature: (JLjava/lang/String;[[I)I
*/ */
JNIEXPORT jintArray JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixGetUIntInfo JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixGetUIntInfo
(JNIEnv *, jclass, jlong, jstring); (JNIEnv *, jclass, jlong, jstring, jobjectArray);
/* /*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
* Method: XGDMatrixNumRow * Method: XGDMatrixNumRow
* Signature: (J)J * Signature: (J[J)I
*/ */
JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixNumRow JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixNumRow
(JNIEnv *, jclass, jlong); (JNIEnv *, jclass, jlong, jlongArray);
/* /*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
* Method: XGBoosterCreate * Method: XGBoosterCreate
* Signature: ([J)J * Signature: ([J[J)I
*/ */
JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterCreate JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterCreate
(JNIEnv *, jclass, jlongArray); (JNIEnv *, jclass, jlongArray, jlongArray);
/* /*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
* Method: XGBoosterFree * Method: XGBoosterFree
* Signature: (J)V * Signature: (J)I
*/ */
JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterFree JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterFree
(JNIEnv *, jclass, jlong); (JNIEnv *, jclass, jlong);
/* /*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
* Method: XGBoosterSetParam * Method: XGBoosterSetParam
* Signature: (JLjava/lang/String;Ljava/lang/String;)V * Signature: (JLjava/lang/String;Ljava/lang/String;)I
*/ */
JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterSetParam JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterSetParam
(JNIEnv *, jclass, jlong, jstring, jstring); (JNIEnv *, jclass, jlong, jstring, jstring);
/* /*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
* Method: XGBoosterUpdateOneIter * Method: XGBoosterUpdateOneIter
* Signature: (JIJ)V * Signature: (JIJ)I
*/ */
JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterUpdateOneIter JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterUpdateOneIter
(JNIEnv *, jclass, jlong, jint, jlong); (JNIEnv *, jclass, jlong, jint, jlong);
/* /*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
* Method: XGBoosterBoostOneIter * Method: XGBoosterBoostOneIter
* Signature: (JJ[F[F)V * Signature: (JJ[F[F)I
*/ */
JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterBoostOneIter JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterBoostOneIter
(JNIEnv *, jclass, jlong, jlong, jfloatArray, jfloatArray); (JNIEnv *, jclass, jlong, jlong, jfloatArray, jfloatArray);
/* /*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
* Method: XGBoosterEvalOneIter * Method: XGBoosterEvalOneIter
* Signature: (JI[J[Ljava/lang/String;)Ljava/lang/String; * Signature: (JI[J[Ljava/lang/String;[Ljava/lang/String;)I
*/ */
JNIEXPORT jstring JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterEvalOneIter JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterEvalOneIter
(JNIEnv *, jclass, jlong, jint, jlongArray, jobjectArray); (JNIEnv *, jclass, jlong, jint, jlongArray, jobjectArray, jobjectArray);
/* /*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
* Method: XGBoosterPredict * Method: XGBoosterPredict
* Signature: (JJIJ)[F * Signature: (JJIJ[[F)I
*/ */
JNIEXPORT jfloatArray JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterPredict JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterPredict
(JNIEnv *, jclass, jlong, jlong, jint, jlong); (JNIEnv *, jclass, jlong, jlong, jint, jlong, jobjectArray);
/* /*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
* Method: XGBoosterLoadModel * Method: XGBoosterLoadModel
* Signature: (JLjava/lang/String;)V * Signature: (JLjava/lang/String;)I
*/ */
JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterLoadModel JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterLoadModel
(JNIEnv *, jclass, jlong, jstring); (JNIEnv *, jclass, jlong, jstring);
/* /*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
* Method: XGBoosterSaveModel * Method: XGBoosterSaveModel
* Signature: (JLjava/lang/String;)V * Signature: (JLjava/lang/String;)I
*/ */
JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterSaveModel JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterSaveModel
(JNIEnv *, jclass, jlong, jstring); (JNIEnv *, jclass, jlong, jstring);
/* /*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
* Method: XGBoosterLoadModelFromBuffer * Method: XGBoosterLoadModelFromBuffer
* Signature: (JJJ)V * Signature: (JJJ)I
*/ */
JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterLoadModelFromBuffer JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterLoadModelFromBuffer
(JNIEnv *, jclass, jlong, jlong, jlong); (JNIEnv *, jclass, jlong, jlong, jlong);
/* /*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
* Method: XGBoosterGetModelRaw * Method: XGBoosterGetModelRaw
* Signature: (J)Ljava/lang/String; * Signature: (J[Ljava/lang/String;)I
*/ */
JNIEXPORT jstring JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterGetModelRaw JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterGetModelRaw
(JNIEnv *, jclass, jlong); (JNIEnv *, jclass, jlong, jobjectArray);
/* /*
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI * Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
* Method: XGBoosterDumpModel * Method: XGBoosterDumpModel
* Signature: (JLjava/lang/String;I)[Ljava/lang/String; * Signature: (JLjava/lang/String;I[[Ljava/lang/String;)I
*/ */
JNIEXPORT jobjectArray JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterDumpModel JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterDumpModel
(JNIEnv *, jclass, jlong, jstring, jint); (JNIEnv *, jclass, jlong, jstring, jint, jobjectArray);
#ifdef __cplusplus #ifdef __cplusplus
} }