update java wrapper for new fault handle API
This commit is contained in:
parent
7755c00721
commit
f73bcd427d
@ -31,6 +31,7 @@ import org.dmlc.xgboost4j.DMatrix;
|
|||||||
import org.dmlc.xgboost4j.demo.util.DataLoader;
|
import org.dmlc.xgboost4j.demo.util.DataLoader;
|
||||||
import org.dmlc.xgboost4j.demo.util.Params;
|
import org.dmlc.xgboost4j.demo.util.Params;
|
||||||
import org.dmlc.xgboost4j.util.Trainer;
|
import org.dmlc.xgboost4j.util.Trainer;
|
||||||
|
import org.dmlc.xgboost4j.util.XgboostError;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* a simple example of java wrapper for xgboost
|
* a simple example of java wrapper for xgboost
|
||||||
@ -52,7 +53,7 @@ public class BasicWalkThrough {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
public static void main(String[] args) throws UnsupportedEncodingException, IOException {
|
public static void main(String[] args) throws UnsupportedEncodingException, IOException, XgboostError {
|
||||||
// load file from text file, also binary buffer generated by xgboost4j
|
// load file from text file, also binary buffer generated by xgboost4j
|
||||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||||
|
|||||||
@ -23,13 +23,14 @@ import org.dmlc.xgboost4j.Booster;
|
|||||||
import org.dmlc.xgboost4j.DMatrix;
|
import org.dmlc.xgboost4j.DMatrix;
|
||||||
import org.dmlc.xgboost4j.demo.util.Params;
|
import org.dmlc.xgboost4j.demo.util.Params;
|
||||||
import org.dmlc.xgboost4j.util.Trainer;
|
import org.dmlc.xgboost4j.util.Trainer;
|
||||||
|
import org.dmlc.xgboost4j.util.XgboostError;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* example for start from a initial base prediction
|
* example for start from a initial base prediction
|
||||||
* @author hzx
|
* @author hzx
|
||||||
*/
|
*/
|
||||||
public class BoostFromPrediction {
|
public class BoostFromPrediction {
|
||||||
public static void main(String[] args) {
|
public static void main(String[] args) throws XgboostError {
|
||||||
System.out.println("start running example to start from a initial prediction");
|
System.out.println("start running example to start from a initial prediction");
|
||||||
|
|
||||||
// load file from text file, also binary buffer generated by xgboost4j
|
// load file from text file, also binary buffer generated by xgboost4j
|
||||||
|
|||||||
@ -19,13 +19,14 @@ import java.io.IOException;
|
|||||||
import org.dmlc.xgboost4j.DMatrix;
|
import org.dmlc.xgboost4j.DMatrix;
|
||||||
import org.dmlc.xgboost4j.util.Trainer;
|
import org.dmlc.xgboost4j.util.Trainer;
|
||||||
import org.dmlc.xgboost4j.demo.util.Params;
|
import org.dmlc.xgboost4j.demo.util.Params;
|
||||||
|
import org.dmlc.xgboost4j.util.XgboostError;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* an example of cross validation
|
* an example of cross validation
|
||||||
* @author hzx
|
* @author hzx
|
||||||
*/
|
*/
|
||||||
public class CrossValidation {
|
public class CrossValidation {
|
||||||
public static void main(String[] args) throws IOException {
|
public static void main(String[] args) throws IOException, XgboostError {
|
||||||
//load train mat
|
//load train mat
|
||||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||||
|
|
||||||
|
|||||||
@ -19,12 +19,15 @@ import java.util.AbstractMap;
|
|||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
import org.apache.commons.logging.Log;
|
||||||
|
import org.apache.commons.logging.LogFactory;
|
||||||
import org.dmlc.xgboost4j.Booster;
|
import org.dmlc.xgboost4j.Booster;
|
||||||
import org.dmlc.xgboost4j.IEvaluation;
|
import org.dmlc.xgboost4j.IEvaluation;
|
||||||
import org.dmlc.xgboost4j.DMatrix;
|
import org.dmlc.xgboost4j.DMatrix;
|
||||||
import org.dmlc.xgboost4j.IObjective;
|
import org.dmlc.xgboost4j.IObjective;
|
||||||
import org.dmlc.xgboost4j.demo.util.Params;
|
import org.dmlc.xgboost4j.demo.util.Params;
|
||||||
import org.dmlc.xgboost4j.util.Trainer;
|
import org.dmlc.xgboost4j.util.Trainer;
|
||||||
|
import org.dmlc.xgboost4j.util.XgboostError;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* an example user define objective and eval
|
* an example user define objective and eval
|
||||||
@ -40,6 +43,8 @@ public class CustomObjective {
|
|||||||
* loglikelihoode loss obj function
|
* loglikelihoode loss obj function
|
||||||
*/
|
*/
|
||||||
public static class LogRegObj implements IObjective {
|
public static class LogRegObj implements IObjective {
|
||||||
|
private static final Log logger = LogFactory.getLog(LogRegObj.class);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* simple sigmoid func
|
* simple sigmoid func
|
||||||
* @param input
|
* @param input
|
||||||
@ -66,7 +71,13 @@ public class CustomObjective {
|
|||||||
public List<float[]> getGradient(float[][] predicts, DMatrix dtrain) {
|
public List<float[]> getGradient(float[][] predicts, DMatrix dtrain) {
|
||||||
int nrow = predicts.length;
|
int nrow = predicts.length;
|
||||||
List<float[]> gradients = new ArrayList<>();
|
List<float[]> gradients = new ArrayList<>();
|
||||||
float[] labels = dtrain.getLabel();
|
float[] labels;
|
||||||
|
try {
|
||||||
|
labels = dtrain.getLabel();
|
||||||
|
} catch (XgboostError ex) {
|
||||||
|
logger.error(ex);
|
||||||
|
return null;
|
||||||
|
}
|
||||||
float[] grad = new float[nrow];
|
float[] grad = new float[nrow];
|
||||||
float[] hess = new float[nrow];
|
float[] hess = new float[nrow];
|
||||||
|
|
||||||
@ -93,6 +104,8 @@ public class CustomObjective {
|
|||||||
* Take this in mind when you use the customization, and maybe you need write customized evaluation function
|
* Take this in mind when you use the customization, and maybe you need write customized evaluation function
|
||||||
*/
|
*/
|
||||||
public static class EvalError implements IEvaluation {
|
public static class EvalError implements IEvaluation {
|
||||||
|
private static final Log logger = LogFactory.getLog(EvalError.class);
|
||||||
|
|
||||||
String evalMetric = "custom_error";
|
String evalMetric = "custom_error";
|
||||||
|
|
||||||
public EvalError() {
|
public EvalError() {
|
||||||
@ -106,7 +119,13 @@ public class CustomObjective {
|
|||||||
@Override
|
@Override
|
||||||
public float eval(float[][] predicts, DMatrix dmat) {
|
public float eval(float[][] predicts, DMatrix dmat) {
|
||||||
float error = 0f;
|
float error = 0f;
|
||||||
float[] labels = dmat.getLabel();
|
float[] labels;
|
||||||
|
try {
|
||||||
|
labels = dmat.getLabel();
|
||||||
|
} catch (XgboostError ex) {
|
||||||
|
logger.error(ex);
|
||||||
|
return -1f;
|
||||||
|
}
|
||||||
int nrow = predicts.length;
|
int nrow = predicts.length;
|
||||||
for(int i=0; i<nrow; i++) {
|
for(int i=0; i<nrow; i++) {
|
||||||
if(labels[i]==0f && predicts[i][0]>0) {
|
if(labels[i]==0f && predicts[i][0]>0) {
|
||||||
@ -121,7 +140,7 @@ public class CustomObjective {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public static void main(String[] args) {
|
public static void main(String[] args) throws XgboostError {
|
||||||
//load train mat (svmlight format)
|
//load train mat (svmlight format)
|
||||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||||
//load valid mat (svmlight format)
|
//load valid mat (svmlight format)
|
||||||
|
|||||||
@ -23,13 +23,14 @@ import org.dmlc.xgboost4j.Booster;
|
|||||||
import org.dmlc.xgboost4j.DMatrix;
|
import org.dmlc.xgboost4j.DMatrix;
|
||||||
import org.dmlc.xgboost4j.demo.util.Params;
|
import org.dmlc.xgboost4j.demo.util.Params;
|
||||||
import org.dmlc.xgboost4j.util.Trainer;
|
import org.dmlc.xgboost4j.util.Trainer;
|
||||||
|
import org.dmlc.xgboost4j.util.XgboostError;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* simple example for using external memory version
|
* simple example for using external memory version
|
||||||
* @author hzx
|
* @author hzx
|
||||||
*/
|
*/
|
||||||
public class ExternalMemory {
|
public class ExternalMemory {
|
||||||
public static void main(String[] args) {
|
public static void main(String[] args) throws XgboostError {
|
||||||
//this is the only difference, add a # followed by a cache prefix name
|
//this is the only difference, add a # followed by a cache prefix name
|
||||||
//several cache file with the prefix will be generated
|
//several cache file with the prefix will be generated
|
||||||
//currently only support convert from libsvm file
|
//currently only support convert from libsvm file
|
||||||
|
|||||||
@ -24,6 +24,7 @@ import org.dmlc.xgboost4j.DMatrix;
|
|||||||
import org.dmlc.xgboost4j.demo.util.CustomEval;
|
import org.dmlc.xgboost4j.demo.util.CustomEval;
|
||||||
import org.dmlc.xgboost4j.demo.util.Params;
|
import org.dmlc.xgboost4j.demo.util.Params;
|
||||||
import org.dmlc.xgboost4j.util.Trainer;
|
import org.dmlc.xgboost4j.util.Trainer;
|
||||||
|
import org.dmlc.xgboost4j.util.XgboostError;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* this is an example of fit generalized linear model in xgboost
|
* this is an example of fit generalized linear model in xgboost
|
||||||
@ -31,7 +32,7 @@ import org.dmlc.xgboost4j.util.Trainer;
|
|||||||
* @author hzx
|
* @author hzx
|
||||||
*/
|
*/
|
||||||
public class GeneralizedLinearModel {
|
public class GeneralizedLinearModel {
|
||||||
public static void main(String[] args) {
|
public static void main(String[] args) throws XgboostError {
|
||||||
// load file from text file, also binary buffer generated by xgboost4j
|
// load file from text file, also binary buffer generated by xgboost4j
|
||||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||||
|
|||||||
@ -25,13 +25,14 @@ import org.dmlc.xgboost4j.util.Trainer;
|
|||||||
|
|
||||||
import org.dmlc.xgboost4j.demo.util.CustomEval;
|
import org.dmlc.xgboost4j.demo.util.CustomEval;
|
||||||
import org.dmlc.xgboost4j.demo.util.Params;
|
import org.dmlc.xgboost4j.demo.util.Params;
|
||||||
|
import org.dmlc.xgboost4j.util.XgboostError;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* predict first ntree
|
* predict first ntree
|
||||||
* @author hzx
|
* @author hzx
|
||||||
*/
|
*/
|
||||||
public class PredictFirstNtree {
|
public class PredictFirstNtree {
|
||||||
public static void main(String[] args) {
|
public static void main(String[] args) throws XgboostError {
|
||||||
// load file from text file, also binary buffer generated by xgboost4j
|
// load file from text file, also binary buffer generated by xgboost4j
|
||||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||||
|
|||||||
@ -24,13 +24,14 @@ import org.dmlc.xgboost4j.Booster;
|
|||||||
import org.dmlc.xgboost4j.DMatrix;
|
import org.dmlc.xgboost4j.DMatrix;
|
||||||
import org.dmlc.xgboost4j.util.Trainer;
|
import org.dmlc.xgboost4j.util.Trainer;
|
||||||
import org.dmlc.xgboost4j.demo.util.Params;
|
import org.dmlc.xgboost4j.demo.util.Params;
|
||||||
|
import org.dmlc.xgboost4j.util.XgboostError;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* predict leaf indices
|
* predict leaf indices
|
||||||
* @author hzx
|
* @author hzx
|
||||||
*/
|
*/
|
||||||
public class PredictLeafIndices {
|
public class PredictLeafIndices {
|
||||||
public static void main(String[] args) {
|
public static void main(String[] args) throws XgboostError {
|
||||||
// load file from text file, also binary buffer generated by xgboost4j
|
// load file from text file, also binary buffer generated by xgboost4j
|
||||||
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
DMatrix trainMat = new DMatrix("../../demo/data/agaricus.txt.train");
|
||||||
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
DMatrix testMat = new DMatrix("../../demo/data/agaricus.txt.test");
|
||||||
|
|||||||
@ -15,14 +15,18 @@
|
|||||||
*/
|
*/
|
||||||
package org.dmlc.xgboost4j.demo.util;
|
package org.dmlc.xgboost4j.demo.util;
|
||||||
|
|
||||||
|
import org.apache.commons.logging.Log;
|
||||||
|
import org.apache.commons.logging.LogFactory;
|
||||||
import org.dmlc.xgboost4j.DMatrix;
|
import org.dmlc.xgboost4j.DMatrix;
|
||||||
import org.dmlc.xgboost4j.IEvaluation;
|
import org.dmlc.xgboost4j.IEvaluation;
|
||||||
|
import org.dmlc.xgboost4j.util.XgboostError;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* a util evaluation class for examples
|
* a util evaluation class for examples
|
||||||
* @author hzx
|
* @author hzx
|
||||||
*/
|
*/
|
||||||
public class CustomEval implements IEvaluation {
|
public class CustomEval implements IEvaluation {
|
||||||
|
private static final Log logger = LogFactory.getLog(CustomEval.class);
|
||||||
|
|
||||||
String evalMetric = "custom_error";
|
String evalMetric = "custom_error";
|
||||||
|
|
||||||
@ -34,7 +38,13 @@ public class CustomEval implements IEvaluation {
|
|||||||
@Override
|
@Override
|
||||||
public float eval(float[][] predicts, DMatrix dmat) {
|
public float eval(float[][] predicts, DMatrix dmat) {
|
||||||
float error = 0f;
|
float error = 0f;
|
||||||
float[] labels = dmat.getLabel();
|
float[] labels;
|
||||||
|
try {
|
||||||
|
labels = dmat.getLabel();
|
||||||
|
} catch (XgboostError ex) {
|
||||||
|
logger.error(ex);
|
||||||
|
return -1f;
|
||||||
|
}
|
||||||
int nrow = predicts.length;
|
int nrow = predicts.length;
|
||||||
for(int i=0; i<nrow; i++) {
|
for(int i=0; i<nrow; i++) {
|
||||||
if(labels[i]==0f && predicts[i][0]>0.5) {
|
if(labels[i]==0f && predicts[i][0]>0.5) {
|
||||||
|
|||||||
@ -77,10 +77,8 @@ public class DataLoader {
|
|||||||
reader.close();
|
reader.close();
|
||||||
in.close();
|
in.close();
|
||||||
|
|
||||||
Float[] flabels = (Float[]) tlabels.toArray();
|
denseData.labels = ArrayUtils.toPrimitive(tlabels.toArray(new Float[tlabels.size()]));
|
||||||
denseData.labels = ArrayUtils.toPrimitive(flabels);
|
denseData.data = ArrayUtils.toPrimitive(tdata.toArray(new Float[tdata.size()]));
|
||||||
Float[] fdata = (Float[]) tdata.toArray();
|
|
||||||
denseData.data = ArrayUtils.toPrimitive(fdata);
|
|
||||||
|
|
||||||
return denseData;
|
return denseData;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -30,6 +30,8 @@ import org.apache.commons.logging.Log;
|
|||||||
import org.apache.commons.logging.LogFactory;
|
import org.apache.commons.logging.LogFactory;
|
||||||
|
|
||||||
import org.dmlc.xgboost4j.util.Initializer;
|
import org.dmlc.xgboost4j.util.Initializer;
|
||||||
|
import org.dmlc.xgboost4j.util.ErrorHandle;
|
||||||
|
import org.dmlc.xgboost4j.util.XgboostError;
|
||||||
import org.dmlc.xgboost4j.wrapper.XgboostJNI;
|
import org.dmlc.xgboost4j.wrapper.XgboostJNI;
|
||||||
|
|
||||||
|
|
||||||
@ -57,8 +59,9 @@ public final class Booster {
|
|||||||
* init Booster from dMatrixs
|
* init Booster from dMatrixs
|
||||||
* @param params parameters
|
* @param params parameters
|
||||||
* @param dMatrixs DMatrix array
|
* @param dMatrixs DMatrix array
|
||||||
|
* @throws org.dmlc.xgboost4j.util.XgboostError
|
||||||
*/
|
*/
|
||||||
public Booster(Iterable<Entry<String, Object>> params, DMatrix[] dMatrixs) {
|
public Booster(Iterable<Entry<String, Object>> params, DMatrix[] dMatrixs) throws XgboostError {
|
||||||
init(dMatrixs);
|
init(dMatrixs);
|
||||||
setParam("seed","0");
|
setParam("seed","0");
|
||||||
setParams(params);
|
setParams(params);
|
||||||
@ -70,9 +73,11 @@ public final class Booster {
|
|||||||
* load model from modelPath
|
* load model from modelPath
|
||||||
* @param params parameters
|
* @param params parameters
|
||||||
* @param modelPath booster modelPath (model generated by booster.saveModel)
|
* @param modelPath booster modelPath (model generated by booster.saveModel)
|
||||||
|
* @throws org.dmlc.xgboost4j.util.XgboostError
|
||||||
*/
|
*/
|
||||||
public Booster(Iterable<Entry<String, Object>> params, String modelPath) {
|
public Booster(Iterable<Entry<String, Object>> params, String modelPath) throws XgboostError {
|
||||||
handle = XgboostJNI.XGBoosterCreate(new long[] {});
|
long[] out = new long[1];
|
||||||
|
init(null);
|
||||||
loadModel(modelPath);
|
loadModel(modelPath);
|
||||||
setParam("seed","0");
|
setParam("seed","0");
|
||||||
setParams(params);
|
setParams(params);
|
||||||
@ -81,28 +86,33 @@ public final class Booster {
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
private void init(DMatrix[] dMatrixs) {
|
private void init(DMatrix[] dMatrixs) throws XgboostError {
|
||||||
long[] handles = null;
|
long[] handles = null;
|
||||||
if(dMatrixs != null) {
|
if(dMatrixs != null) {
|
||||||
handles = dMatrixs2handles(dMatrixs);
|
handles = dMatrixs2handles(dMatrixs);
|
||||||
}
|
}
|
||||||
handle = XgboostJNI.XGBoosterCreate(handles);
|
long[] out = new long[1];
|
||||||
|
ErrorHandle.checkCall(XgboostJNI.XGBoosterCreate(handles, out));
|
||||||
|
|
||||||
|
handle = out[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* set parameter
|
* set parameter
|
||||||
* @param key param name
|
* @param key param name
|
||||||
* @param value param value
|
* @param value param value
|
||||||
|
* @throws org.dmlc.xgboost4j.util.XgboostError
|
||||||
*/
|
*/
|
||||||
public final void setParam(String key, String value) {
|
public final void setParam(String key, String value) throws XgboostError {
|
||||||
XgboostJNI.XGBoosterSetParam(handle, key, value);
|
ErrorHandle.checkCall(XgboostJNI.XGBoosterSetParam(handle, key, value));
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* set parameters
|
* set parameters
|
||||||
* @param params parameters key-value map
|
* @param params parameters key-value map
|
||||||
|
* @throws org.dmlc.xgboost4j.util.XgboostError
|
||||||
*/
|
*/
|
||||||
public void setParams(Iterable<Entry<String, Object>> params) {
|
public void setParams(Iterable<Entry<String, Object>> params) throws XgboostError {
|
||||||
if(params!=null) {
|
if(params!=null) {
|
||||||
for(Map.Entry<String, Object> entry : params) {
|
for(Map.Entry<String, Object> entry : params) {
|
||||||
setParam(entry.getKey(), entry.getValue().toString());
|
setParam(entry.getKey(), entry.getValue().toString());
|
||||||
@ -115,9 +125,10 @@ public final class Booster {
|
|||||||
* Update (one iteration)
|
* Update (one iteration)
|
||||||
* @param dtrain training data
|
* @param dtrain training data
|
||||||
* @param iter current iteration number
|
* @param iter current iteration number
|
||||||
|
* @throws org.dmlc.xgboost4j.util.XgboostError
|
||||||
*/
|
*/
|
||||||
public void update(DMatrix dtrain, int iter) {
|
public void update(DMatrix dtrain, int iter) throws XgboostError {
|
||||||
XgboostJNI.XGBoosterUpdateOneIter(handle, iter, dtrain.getHandle());
|
ErrorHandle.checkCall(XgboostJNI.XGBoosterUpdateOneIter(handle, iter, dtrain.getHandle()));
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -125,8 +136,9 @@ public final class Booster {
|
|||||||
* @param dtrain training data
|
* @param dtrain training data
|
||||||
* @param iter current iteration number
|
* @param iter current iteration number
|
||||||
* @param obj customized objective class
|
* @param obj customized objective class
|
||||||
|
* @throws org.dmlc.xgboost4j.util.XgboostError
|
||||||
*/
|
*/
|
||||||
public void update(DMatrix dtrain, int iter, IObjective obj) {
|
public void update(DMatrix dtrain, int iter, IObjective obj) throws XgboostError {
|
||||||
float[][] predicts = predict(dtrain, true);
|
float[][] predicts = predict(dtrain, true);
|
||||||
List<float[]> gradients = obj.getGradient(predicts, dtrain);
|
List<float[]> gradients = obj.getGradient(predicts, dtrain);
|
||||||
boost(dtrain, gradients.get(0), gradients.get(1));
|
boost(dtrain, gradients.get(0), gradients.get(1));
|
||||||
@ -137,12 +149,13 @@ public final class Booster {
|
|||||||
* @param dtrain training data
|
* @param dtrain training data
|
||||||
* @param grad first order of gradient
|
* @param grad first order of gradient
|
||||||
* @param hess seconde order of gradient
|
* @param hess seconde order of gradient
|
||||||
|
* @throws org.dmlc.xgboost4j.util.XgboostError
|
||||||
*/
|
*/
|
||||||
public void boost(DMatrix dtrain, float[] grad, float[] hess) {
|
public void boost(DMatrix dtrain, float[] grad, float[] hess) throws XgboostError {
|
||||||
if(grad.length != hess.length) {
|
if(grad.length != hess.length) {
|
||||||
throw new AssertionError(String.format("grad/hess length mismatch %s / %s", grad.length, hess.length));
|
throw new AssertionError(String.format("grad/hess length mismatch %s / %s", grad.length, hess.length));
|
||||||
}
|
}
|
||||||
XgboostJNI.XGBoosterBoostOneIter(handle, dtrain.getHandle(), grad, hess);
|
ErrorHandle.checkCall(XgboostJNI.XGBoosterBoostOneIter(handle, dtrain.getHandle(), grad, hess));
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -151,11 +164,13 @@ public final class Booster {
|
|||||||
* @param evalNames name for eval dmatrixs, used for check results
|
* @param evalNames name for eval dmatrixs, used for check results
|
||||||
* @param iter current eval iteration
|
* @param iter current eval iteration
|
||||||
* @return eval information
|
* @return eval information
|
||||||
|
* @throws org.dmlc.xgboost4j.util.XgboostError
|
||||||
*/
|
*/
|
||||||
public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, int iter) {
|
public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, int iter) throws XgboostError {
|
||||||
long[] handles = dMatrixs2handles(evalMatrixs);
|
long[] handles = dMatrixs2handles(evalMatrixs);
|
||||||
String evalInfo = XgboostJNI.XGBoosterEvalOneIter(handle, iter, handles, evalNames);
|
String[] evalInfo = new String[1];
|
||||||
return evalInfo;
|
ErrorHandle.checkCall(XgboostJNI.XGBoosterEvalOneIter(handle, iter, handles, evalNames, evalInfo));
|
||||||
|
return evalInfo[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -165,8 +180,9 @@ public final class Booster {
|
|||||||
* @param iter
|
* @param iter
|
||||||
* @param eval
|
* @param eval
|
||||||
* @return eval information
|
* @return eval information
|
||||||
|
* @throws org.dmlc.xgboost4j.util.XgboostError
|
||||||
*/
|
*/
|
||||||
public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, int iter, IEvaluation eval) {
|
public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, int iter, IEvaluation eval) throws XgboostError {
|
||||||
String evalInfo = "";
|
String evalInfo = "";
|
||||||
for(int i=0; i<evalNames.length; i++) {
|
for(int i=0; i<evalNames.length; i++) {
|
||||||
String evalName = evalNames[i];
|
String evalName = evalNames[i];
|
||||||
@ -184,10 +200,12 @@ public final class Booster {
|
|||||||
* @param evalNames name for eval dmatrixs, used for check results
|
* @param evalNames name for eval dmatrixs, used for check results
|
||||||
* @param iter current eval iteration
|
* @param iter current eval iteration
|
||||||
* @return eval information
|
* @return eval information
|
||||||
|
* @throws org.dmlc.xgboost4j.util.XgboostError
|
||||||
*/
|
*/
|
||||||
public String evalSet(long[] dHandles, String[] evalNames, int iter) {
|
public String evalSet(long[] dHandles, String[] evalNames, int iter) throws XgboostError {
|
||||||
String evalInfo = XgboostJNI.XGBoosterEvalOneIter(handle, iter, dHandles, evalNames);
|
String[] evalInfo = new String[1];
|
||||||
return evalInfo;
|
ErrorHandle.checkCall(XgboostJNI.XGBoosterEvalOneIter(handle, iter, dHandles, evalNames, evalInfo));
|
||||||
|
return evalInfo[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -197,8 +215,9 @@ public final class Booster {
|
|||||||
* @param evalName
|
* @param evalName
|
||||||
* @param iter
|
* @param iter
|
||||||
* @return eval information
|
* @return eval information
|
||||||
|
* @throws org.dmlc.xgboost4j.util.XgboostError
|
||||||
*/
|
*/
|
||||||
public String eval(DMatrix evalMat, String evalName, int iter) {
|
public String eval(DMatrix evalMat, String evalName, int iter) throws XgboostError {
|
||||||
DMatrix[] evalMats = new DMatrix[] {evalMat};
|
DMatrix[] evalMats = new DMatrix[] {evalMat};
|
||||||
String[] evalNames = new String[] {evalName};
|
String[] evalNames = new String[] {evalName};
|
||||||
return evalSet(evalMats, evalNames, iter);
|
return evalSet(evalMats, evalNames, iter);
|
||||||
@ -212,7 +231,7 @@ public final class Booster {
|
|||||||
* @param predLeaf
|
* @param predLeaf
|
||||||
* @return predict results
|
* @return predict results
|
||||||
*/
|
*/
|
||||||
private synchronized float[][] pred(DMatrix data, boolean outPutMargin, long treeLimit, boolean predLeaf) {
|
private synchronized float[][] pred(DMatrix data, boolean outPutMargin, long treeLimit, boolean predLeaf) throws XgboostError {
|
||||||
int optionMask = 0;
|
int optionMask = 0;
|
||||||
if(outPutMargin) {
|
if(outPutMargin) {
|
||||||
optionMask = 1;
|
optionMask = 1;
|
||||||
@ -220,15 +239,16 @@ public final class Booster {
|
|||||||
if(predLeaf) {
|
if(predLeaf) {
|
||||||
optionMask = 2;
|
optionMask = 2;
|
||||||
}
|
}
|
||||||
float[] rawPredicts = XgboostJNI.XGBoosterPredict(handle, data.getHandle(), optionMask, treeLimit);
|
float[][] rawPredicts = new float[1][];
|
||||||
|
ErrorHandle.checkCall(XgboostJNI.XGBoosterPredict(handle, data.getHandle(), optionMask, treeLimit, rawPredicts));
|
||||||
int row = (int) data.rowNum();
|
int row = (int) data.rowNum();
|
||||||
int col = (int) rawPredicts.length/row;
|
int col = (int) rawPredicts[0].length/row;
|
||||||
float[][] predicts = new float[row][col];
|
float[][] predicts = new float[row][col];
|
||||||
int r,c;
|
int r,c;
|
||||||
for(int i=0; i< rawPredicts.length; i++) {
|
for(int i=0; i< rawPredicts[0].length; i++) {
|
||||||
r = i/col;
|
r = i/col;
|
||||||
c = i%col;
|
c = i%col;
|
||||||
predicts[r][c] = rawPredicts[i];
|
predicts[r][c] = rawPredicts[0][i];
|
||||||
}
|
}
|
||||||
return predicts;
|
return predicts;
|
||||||
}
|
}
|
||||||
@ -237,8 +257,9 @@ public final class Booster {
|
|||||||
* Predict with data
|
* Predict with data
|
||||||
* @param data dmatrix storing the input
|
* @param data dmatrix storing the input
|
||||||
* @return predict result
|
* @return predict result
|
||||||
|
* @throws org.dmlc.xgboost4j.util.XgboostError
|
||||||
*/
|
*/
|
||||||
public float[][] predict(DMatrix data) {
|
public float[][] predict(DMatrix data) throws XgboostError {
|
||||||
return pred(data, false, 0, false);
|
return pred(data, false, 0, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -247,8 +268,9 @@ public final class Booster {
|
|||||||
* @param data dmatrix storing the input
|
* @param data dmatrix storing the input
|
||||||
* @param outPutMargin Whether to output the raw untransformed margin value.
|
* @param outPutMargin Whether to output the raw untransformed margin value.
|
||||||
* @return predict result
|
* @return predict result
|
||||||
|
* @throws org.dmlc.xgboost4j.util.XgboostError
|
||||||
*/
|
*/
|
||||||
public float[][] predict(DMatrix data, boolean outPutMargin) {
|
public float[][] predict(DMatrix data, boolean outPutMargin) throws XgboostError {
|
||||||
return pred(data, outPutMargin, 0, false);
|
return pred(data, outPutMargin, 0, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -258,8 +280,9 @@ public final class Booster {
|
|||||||
* @param outPutMargin Whether to output the raw untransformed margin value.
|
* @param outPutMargin Whether to output the raw untransformed margin value.
|
||||||
* @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees).
|
* @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees).
|
||||||
* @return predict result
|
* @return predict result
|
||||||
|
* @throws org.dmlc.xgboost4j.util.XgboostError
|
||||||
*/
|
*/
|
||||||
public float[][] predict(DMatrix data, boolean outPutMargin, long treeLimit) {
|
public float[][] predict(DMatrix data, boolean outPutMargin, long treeLimit) throws XgboostError {
|
||||||
return pred(data, outPutMargin, treeLimit, false);
|
return pred(data, outPutMargin, treeLimit, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -272,8 +295,9 @@ public final class Booster {
|
|||||||
Note that the leaf index of a tree is unique per tree, so you may find leaf 1
|
Note that the leaf index of a tree is unique per tree, so you may find leaf 1
|
||||||
in both tree 1 and tree 0.
|
in both tree 1 and tree 0.
|
||||||
* @return predict result
|
* @return predict result
|
||||||
|
* @throws org.dmlc.xgboost4j.util.XgboostError
|
||||||
*/
|
*/
|
||||||
public float[][] predict(DMatrix data , long treeLimit, boolean predLeaf) {
|
public float[][] predict(DMatrix data , long treeLimit, boolean predLeaf) throws XgboostError {
|
||||||
return pred(data, false, treeLimit, predLeaf);
|
return pred(data, false, treeLimit, predLeaf);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -293,14 +317,16 @@ public final class Booster {
|
|||||||
* get the dump of the model as a string array
|
* get the dump of the model as a string array
|
||||||
* @param withStats Controls whether the split statistics are output.
|
* @param withStats Controls whether the split statistics are output.
|
||||||
* @return dumped model information
|
* @return dumped model information
|
||||||
|
* @throws org.dmlc.xgboost4j.util.XgboostError
|
||||||
*/
|
*/
|
||||||
public String[] getDumpInfo(boolean withStats) {
|
public String[] getDumpInfo(boolean withStats) throws XgboostError {
|
||||||
int statsFlag = 0;
|
int statsFlag = 0;
|
||||||
if(withStats) {
|
if(withStats) {
|
||||||
statsFlag = 1;
|
statsFlag = 1;
|
||||||
}
|
}
|
||||||
String[] modelInfos = XgboostJNI.XGBoosterDumpModel(handle, "", statsFlag);
|
String[][] modelInfos = new String[1][];
|
||||||
return modelInfos;
|
ErrorHandle.checkCall(XgboostJNI.XGBoosterDumpModel(handle, "", statsFlag, modelInfos));
|
||||||
|
return modelInfos[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -308,14 +334,16 @@ public final class Booster {
|
|||||||
* @param featureMap featureMap file
|
* @param featureMap featureMap file
|
||||||
* @param withStats Controls whether the split statistics are output.
|
* @param withStats Controls whether the split statistics are output.
|
||||||
* @return dumped model information
|
* @return dumped model information
|
||||||
|
* @throws org.dmlc.xgboost4j.util.XgboostError
|
||||||
*/
|
*/
|
||||||
public String[] getDumpInfo(String featureMap, boolean withStats) {
|
public String[] getDumpInfo(String featureMap, boolean withStats) throws XgboostError {
|
||||||
int statsFlag = 0;
|
int statsFlag = 0;
|
||||||
if(withStats) {
|
if(withStats) {
|
||||||
statsFlag = 1;
|
statsFlag = 1;
|
||||||
}
|
}
|
||||||
String[] modelInfos = XgboostJNI.XGBoosterDumpModel(handle, featureMap, statsFlag);
|
String[][] modelInfos = new String[1][];
|
||||||
return modelInfos;
|
ErrorHandle.checkCall(XgboostJNI.XGBoosterDumpModel(handle, featureMap, statsFlag, modelInfos));
|
||||||
|
return modelInfos[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -326,8 +354,9 @@ public final class Booster {
|
|||||||
* @throws FileNotFoundException
|
* @throws FileNotFoundException
|
||||||
* @throws UnsupportedEncodingException
|
* @throws UnsupportedEncodingException
|
||||||
* @throws IOException
|
* @throws IOException
|
||||||
|
* @throws org.dmlc.xgboost4j.util.XgboostError
|
||||||
*/
|
*/
|
||||||
public void dumpModel(String modelPath, boolean withStats) throws FileNotFoundException, UnsupportedEncodingException, IOException {
|
public void dumpModel(String modelPath, boolean withStats) throws FileNotFoundException, UnsupportedEncodingException, IOException, XgboostError {
|
||||||
File tf = new File(modelPath);
|
File tf = new File(modelPath);
|
||||||
FileOutputStream out = new FileOutputStream(tf);
|
FileOutputStream out = new FileOutputStream(tf);
|
||||||
BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(out, "UTF-8"));
|
BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(out, "UTF-8"));
|
||||||
@ -352,8 +381,9 @@ public final class Booster {
|
|||||||
* @throws FileNotFoundException
|
* @throws FileNotFoundException
|
||||||
* @throws UnsupportedEncodingException
|
* @throws UnsupportedEncodingException
|
||||||
* @throws IOException
|
* @throws IOException
|
||||||
|
* @throws org.dmlc.xgboost4j.util.XgboostError
|
||||||
*/
|
*/
|
||||||
public void dumpModel(String modelPath, String featureMap, boolean withStats) throws FileNotFoundException, UnsupportedEncodingException, IOException {
|
public void dumpModel(String modelPath, String featureMap, boolean withStats) throws FileNotFoundException, UnsupportedEncodingException, IOException, XgboostError {
|
||||||
File tf = new File(modelPath);
|
File tf = new File(modelPath);
|
||||||
FileOutputStream out = new FileOutputStream(tf);
|
FileOutputStream out = new FileOutputStream(tf);
|
||||||
BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(out, "UTF-8"));
|
BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(out, "UTF-8"));
|
||||||
@ -372,8 +402,9 @@ public final class Booster {
|
|||||||
/**
|
/**
|
||||||
* get importance of each feature
|
* get importance of each feature
|
||||||
* @return featureMap key: feature index, value: feature importance score
|
* @return featureMap key: feature index, value: feature importance score
|
||||||
|
* @throws org.dmlc.xgboost4j.util.XgboostError
|
||||||
*/
|
*/
|
||||||
public Map<String, Integer> getFeatureScore() {
|
public Map<String, Integer> getFeatureScore() throws XgboostError {
|
||||||
String[] modelInfos = getDumpInfo(false);
|
String[] modelInfos = getDumpInfo(false);
|
||||||
Map<String, Integer> featureScore = new HashMap<>();
|
Map<String, Integer> featureScore = new HashMap<>();
|
||||||
for(String tree : modelInfos) {
|
for(String tree : modelInfos) {
|
||||||
@ -400,8 +431,9 @@ public final class Booster {
|
|||||||
* get importance of each feature
|
* get importance of each feature
|
||||||
* @param featureMap file to save dumped model info
|
* @param featureMap file to save dumped model info
|
||||||
* @return featureMap key: feature index, value: feature importance score
|
* @return featureMap key: feature index, value: feature importance score
|
||||||
|
* @throws org.dmlc.xgboost4j.util.XgboostError
|
||||||
*/
|
*/
|
||||||
public Map<String, Integer> getFeatureScore(String featureMap) {
|
public Map<String, Integer> getFeatureScore(String featureMap) throws XgboostError {
|
||||||
String[] modelInfos = getDumpInfo(featureMap, false);
|
String[] modelInfos = getDumpInfo(featureMap, false);
|
||||||
Map<String, Integer> featureScore = new HashMap<>();
|
Map<String, Integer> featureScore = new HashMap<>();
|
||||||
for(String tree : modelInfos) {
|
for(String tree : modelInfos) {
|
||||||
|
|||||||
@ -18,6 +18,8 @@ package org.dmlc.xgboost4j;
|
|||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import org.apache.commons.logging.Log;
|
import org.apache.commons.logging.Log;
|
||||||
import org.apache.commons.logging.LogFactory;
|
import org.apache.commons.logging.LogFactory;
|
||||||
|
import org.dmlc.xgboost4j.util.ErrorHandle;
|
||||||
|
import org.dmlc.xgboost4j.util.XgboostError;
|
||||||
import org.dmlc.xgboost4j.util.Initializer;
|
import org.dmlc.xgboost4j.util.Initializer;
|
||||||
import org.dmlc.xgboost4j.wrapper.XgboostJNI;
|
import org.dmlc.xgboost4j.wrapper.XgboostJNI;
|
||||||
|
|
||||||
@ -50,9 +52,12 @@ public class DMatrix {
|
|||||||
/**
|
/**
|
||||||
* init DMatrix from file (svmlight format)
|
* init DMatrix from file (svmlight format)
|
||||||
* @param dataPath
|
* @param dataPath
|
||||||
|
* @throws org.dmlc.xgboost4j.util.XgboostError
|
||||||
*/
|
*/
|
||||||
public DMatrix(String dataPath) {
|
public DMatrix(String dataPath) throws XgboostError {
|
||||||
handle = XgboostJNI.XGDMatrixCreateFromFile(dataPath, 1);
|
long[] out = new long[1];
|
||||||
|
ErrorHandle.checkCall(XgboostJNI.XGDMatrixCreateFromFile(dataPath, 1, out));
|
||||||
|
handle = out[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -61,17 +66,20 @@ public class DMatrix {
|
|||||||
* @param indices Indices (colIndexs for CSR or rowIndexs for CSC)
|
* @param indices Indices (colIndexs for CSR or rowIndexs for CSC)
|
||||||
* @param data non zero values (sequence by row for CSR or by col for CSC)
|
* @param data non zero values (sequence by row for CSR or by col for CSC)
|
||||||
* @param st sparse matrix type (CSR or CSC)
|
* @param st sparse matrix type (CSR or CSC)
|
||||||
|
* @throws org.dmlc.xgboost4j.util.XgboostError
|
||||||
*/
|
*/
|
||||||
public DMatrix(long[] headers, int[] indices, float[] data, SparseType st) {
|
public DMatrix(long[] headers, int[] indices, float[] data, SparseType st) throws XgboostError {
|
||||||
|
long[] out = new long[1];
|
||||||
if(st == SparseType.CSR) {
|
if(st == SparseType.CSR) {
|
||||||
handle = XgboostJNI.XGDMatrixCreateFromCSR(headers, indices, data);
|
ErrorHandle.checkCall(XgboostJNI.XGDMatrixCreateFromCSR(headers, indices, data, out));
|
||||||
}
|
}
|
||||||
else if(st == SparseType.CSC) {
|
else if(st == SparseType.CSC) {
|
||||||
handle = XgboostJNI.XGDMatrixCreateFromCSC(headers, indices, data);
|
ErrorHandle.checkCall(XgboostJNI.XGDMatrixCreateFromCSC(headers, indices, data, out));
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
throw new UnknownError("unknow sparsetype");
|
throw new UnknownError("unknow sparsetype");
|
||||||
}
|
}
|
||||||
|
handle = out[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -79,9 +87,12 @@ public class DMatrix {
|
|||||||
* @param data data values
|
* @param data data values
|
||||||
* @param nrow number of rows
|
* @param nrow number of rows
|
||||||
* @param ncol number of columns
|
* @param ncol number of columns
|
||||||
|
* @throws org.dmlc.xgboost4j.util.XgboostError
|
||||||
*/
|
*/
|
||||||
public DMatrix(float[] data, int nrow, int ncol) {
|
public DMatrix(float[] data, int nrow, int ncol) throws XgboostError {
|
||||||
handle = XgboostJNI.XGDMatrixCreateFromMat(data, nrow, ncol, 0.0f);
|
long[] out = new long[1];
|
||||||
|
ErrorHandle.checkCall(XgboostJNI.XGDMatrixCreateFromMat(data, nrow, ncol, 0.0f, out));
|
||||||
|
handle = out[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -98,33 +109,36 @@ public class DMatrix {
|
|||||||
* set label of dmatrix
|
* set label of dmatrix
|
||||||
* @param labels
|
* @param labels
|
||||||
*/
|
*/
|
||||||
public void setLabel(float[] labels) {
|
public void setLabel(float[] labels) throws XgboostError {
|
||||||
XgboostJNI.XGDMatrixSetFloatInfo(handle, "label", labels);
|
ErrorHandle.checkCall(XgboostJNI.XGDMatrixSetFloatInfo(handle, "label", labels));
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* set weight of each instance
|
* set weight of each instance
|
||||||
* @param weights
|
* @param weights
|
||||||
|
* @throws org.dmlc.xgboost4j.util.XgboostError
|
||||||
*/
|
*/
|
||||||
public void setWeight(float[] weights) {
|
public void setWeight(float[] weights) throws XgboostError {
|
||||||
XgboostJNI.XGDMatrixSetFloatInfo(handle, "weight", weights);
|
ErrorHandle.checkCall(XgboostJNI.XGDMatrixSetFloatInfo(handle, "weight", weights));
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* if specified, xgboost will start from this init margin
|
* if specified, xgboost will start from this init margin
|
||||||
* can be used to specify initial prediction to boost from
|
* can be used to specify initial prediction to boost from
|
||||||
* @param baseMargin
|
* @param baseMargin
|
||||||
|
* @throws org.dmlc.xgboost4j.util.XgboostError
|
||||||
*/
|
*/
|
||||||
public void setBaseMargin(float[] baseMargin) {
|
public void setBaseMargin(float[] baseMargin) throws XgboostError {
|
||||||
XgboostJNI.XGDMatrixSetFloatInfo(handle, "base_margin", baseMargin);
|
ErrorHandle.checkCall(XgboostJNI.XGDMatrixSetFloatInfo(handle, "base_margin", baseMargin));
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* if specified, xgboost will start from this init margin
|
* if specified, xgboost will start from this init margin
|
||||||
* can be used to specify initial prediction to boost from
|
* can be used to specify initial prediction to boost from
|
||||||
* @param baseMargin
|
* @param baseMargin
|
||||||
|
* @throws org.dmlc.xgboost4j.util.XgboostError
|
||||||
*/
|
*/
|
||||||
public void setBaseMargin(float[][] baseMargin) {
|
public void setBaseMargin(float[][] baseMargin) throws XgboostError {
|
||||||
float[] flattenMargin = flatten(baseMargin);
|
float[] flattenMargin = flatten(baseMargin);
|
||||||
setBaseMargin(flattenMargin);
|
setBaseMargin(flattenMargin);
|
||||||
}
|
}
|
||||||
@ -132,42 +146,48 @@ public class DMatrix {
|
|||||||
/**
|
/**
|
||||||
* Set group sizes of DMatrix (used for ranking)
|
* Set group sizes of DMatrix (used for ranking)
|
||||||
* @param group
|
* @param group
|
||||||
|
* @throws org.dmlc.xgboost4j.util.XgboostError
|
||||||
*/
|
*/
|
||||||
public void setGroup(int[] group) {
|
public void setGroup(int[] group) throws XgboostError {
|
||||||
XgboostJNI.XGDMatrixSetGroup(handle, group);
|
ErrorHandle.checkCall(XgboostJNI.XGDMatrixSetGroup(handle, group));
|
||||||
}
|
}
|
||||||
|
|
||||||
private float[] getFloatInfo(String field) {
|
private float[] getFloatInfo(String field) throws XgboostError {
|
||||||
float[] infos = XgboostJNI.XGDMatrixGetFloatInfo(handle, field);
|
float[][] infos = new float[1][];
|
||||||
return infos;
|
ErrorHandle.checkCall(XgboostJNI.XGDMatrixGetFloatInfo(handle, field, infos));
|
||||||
|
return infos[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
private int[] getIntInfo(String field) {
|
private int[] getIntInfo(String field) throws XgboostError {
|
||||||
int[] infos = XgboostJNI.XGDMatrixGetUIntInfo(handle, field);
|
int[][] infos = new int[1][];
|
||||||
return infos;
|
ErrorHandle.checkCall(XgboostJNI.XGDMatrixGetUIntInfo(handle, field, infos));
|
||||||
|
return infos[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* get label values
|
* get label values
|
||||||
* @return label
|
* @return label
|
||||||
|
* @throws org.dmlc.xgboost4j.util.XgboostError
|
||||||
*/
|
*/
|
||||||
public float[] getLabel() {
|
public float[] getLabel() throws XgboostError {
|
||||||
return getFloatInfo("label");
|
return getFloatInfo("label");
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* get weight of the DMatrix
|
* get weight of the DMatrix
|
||||||
* @return weights
|
* @return weights
|
||||||
|
* @throws org.dmlc.xgboost4j.util.XgboostError
|
||||||
*/
|
*/
|
||||||
public float[] getWeight() {
|
public float[] getWeight() throws XgboostError {
|
||||||
return getFloatInfo("weight");
|
return getFloatInfo("weight");
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* get base margin of the DMatrix
|
* get base margin of the DMatrix
|
||||||
* @return base margin
|
* @return base margin
|
||||||
|
* @throws org.dmlc.xgboost4j.util.XgboostError
|
||||||
*/
|
*/
|
||||||
public float[] getBaseMargin() {
|
public float[] getBaseMargin() throws XgboostError {
|
||||||
return getFloatInfo("base_margin");
|
return getFloatInfo("base_margin");
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -175,9 +195,12 @@ public class DMatrix {
|
|||||||
* Slice the DMatrix and return a new DMatrix that only contains `rowIndex`.
|
* Slice the DMatrix and return a new DMatrix that only contains `rowIndex`.
|
||||||
* @param rowIndex
|
* @param rowIndex
|
||||||
* @return sliced new DMatrix
|
* @return sliced new DMatrix
|
||||||
|
* @throws org.dmlc.xgboost4j.util.XgboostError
|
||||||
*/
|
*/
|
||||||
public DMatrix slice(int[] rowIndex) {
|
public DMatrix slice(int[] rowIndex) throws XgboostError {
|
||||||
long sHandle = XgboostJNI.XGDMatrixSliceDMatrix(handle, rowIndex);
|
long[] out = new long[1];
|
||||||
|
ErrorHandle.checkCall(XgboostJNI.XGDMatrixSliceDMatrix(handle, rowIndex, out));
|
||||||
|
long sHandle = out[0];
|
||||||
DMatrix sMatrix = new DMatrix(sHandle);
|
DMatrix sMatrix = new DMatrix(sHandle);
|
||||||
return sMatrix;
|
return sMatrix;
|
||||||
}
|
}
|
||||||
@ -185,9 +208,12 @@ public class DMatrix {
|
|||||||
/**
|
/**
|
||||||
* get the row number of DMatrix
|
* get the row number of DMatrix
|
||||||
* @return number of rows
|
* @return number of rows
|
||||||
|
* @throws org.dmlc.xgboost4j.util.XgboostError
|
||||||
*/
|
*/
|
||||||
public long rowNum() {
|
public long rowNum() throws XgboostError {
|
||||||
return XgboostJNI.XGDMatrixNumRow(handle);
|
long[] rowNum = new long[1];
|
||||||
|
ErrorHandle.checkCall(XgboostJNI.XGDMatrixNumRow(handle,rowNum));
|
||||||
|
return rowNum[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@ -37,8 +37,9 @@ public class CVPack {
|
|||||||
* @param dtrain train data
|
* @param dtrain train data
|
||||||
* @param dtest test data
|
* @param dtest test data
|
||||||
* @param params parameters
|
* @param params parameters
|
||||||
|
* @throws org.dmlc.xgboost4j.util.XgboostError
|
||||||
*/
|
*/
|
||||||
public CVPack(DMatrix dtrain, DMatrix dtest, Iterable<Map.Entry<String, Object>> params) {
|
public CVPack(DMatrix dtrain, DMatrix dtest, Iterable<Map.Entry<String, Object>> params) throws XgboostError {
|
||||||
dmats = new DMatrix[] {dtrain, dtest};
|
dmats = new DMatrix[] {dtrain, dtest};
|
||||||
booster = new Booster(params, dmats);
|
booster = new Booster(params, dmats);
|
||||||
names = new String[] {"train", "test"};
|
names = new String[] {"train", "test"};
|
||||||
@ -49,8 +50,9 @@ public class CVPack {
|
|||||||
/**
|
/**
|
||||||
* update one iteration
|
* update one iteration
|
||||||
* @param iter iteration num
|
* @param iter iteration num
|
||||||
|
* @throws org.dmlc.xgboost4j.util.XgboostError
|
||||||
*/
|
*/
|
||||||
public void update(int iter) {
|
public void update(int iter) throws XgboostError {
|
||||||
booster.update(dtrain, iter);
|
booster.update(dtrain, iter);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -58,8 +60,9 @@ public class CVPack {
|
|||||||
* update one iteration
|
* update one iteration
|
||||||
* @param iter iteration num
|
* @param iter iteration num
|
||||||
* @param obj customized objective
|
* @param obj customized objective
|
||||||
|
* @throws org.dmlc.xgboost4j.util.XgboostError
|
||||||
*/
|
*/
|
||||||
public void update(int iter, IObjective obj) {
|
public void update(int iter, IObjective obj) throws XgboostError {
|
||||||
booster.update(dtrain, iter, obj);
|
booster.update(dtrain, iter, obj);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -67,8 +70,9 @@ public class CVPack {
|
|||||||
* evaluation
|
* evaluation
|
||||||
* @param iter iteration num
|
* @param iter iteration num
|
||||||
* @return
|
* @return
|
||||||
|
* @throws org.dmlc.xgboost4j.util.XgboostError
|
||||||
*/
|
*/
|
||||||
public String eval(int iter) {
|
public String eval(int iter) throws XgboostError {
|
||||||
return booster.evalSet(dmats, names, iter);
|
return booster.evalSet(dmats, names, iter);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -77,8 +81,9 @@ public class CVPack {
|
|||||||
* @param iter iteration num
|
* @param iter iteration num
|
||||||
* @param eval customized eval
|
* @param eval customized eval
|
||||||
* @return
|
* @return
|
||||||
|
* @throws org.dmlc.xgboost4j.util.XgboostError
|
||||||
*/
|
*/
|
||||||
public String eval(int iter, IEvaluation eval) {
|
public String eval(int iter, IEvaluation eval) throws XgboostError {
|
||||||
return booster.evalSet(dmats, names, iter, eval);
|
return booster.evalSet(dmats, names, iter, eval);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -0,0 +1,50 @@
|
|||||||
|
/*
|
||||||
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
package org.dmlc.xgboost4j.util;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import org.apache.commons.logging.Log;
|
||||||
|
import org.apache.commons.logging.LogFactory;
|
||||||
|
import org.dmlc.xgboost4j.wrapper.XgboostJNI;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* error handle for Xgboost
|
||||||
|
* @author hzx
|
||||||
|
*/
|
||||||
|
public class ErrorHandle {
|
||||||
|
private static final Log logger = LogFactory.getLog(ErrorHandle.class);
|
||||||
|
|
||||||
|
//load native library
|
||||||
|
static {
|
||||||
|
try {
|
||||||
|
Initializer.InitXgboost();
|
||||||
|
} catch (IOException ex) {
|
||||||
|
logger.error("load native library failed.");
|
||||||
|
logger.error(ex);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* check the return value of C API
|
||||||
|
* @param ret return valud of xgboostJNI C API call
|
||||||
|
* @throws org.dmlc.xgboost4j.util.XgboostError
|
||||||
|
*/
|
||||||
|
public static void checkCall(int ret) throws XgboostError {
|
||||||
|
if(ret != 0) {
|
||||||
|
throw new XgboostError(XgboostJNI.XGBGetLastError());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -47,7 +47,7 @@ public class Trainer {
|
|||||||
* @return trained booster
|
* @return trained booster
|
||||||
*/
|
*/
|
||||||
public static Booster train(Iterable<Entry<String, Object>> params, DMatrix dtrain, int round,
|
public static Booster train(Iterable<Entry<String, Object>> params, DMatrix dtrain, int round,
|
||||||
Iterable<Entry<String, DMatrix>> watchs, IObjective obj, IEvaluation eval) {
|
Iterable<Entry<String, DMatrix>> watchs, IObjective obj, IEvaluation eval) throws XgboostError {
|
||||||
|
|
||||||
//collect eval matrixs
|
//collect eval matrixs
|
||||||
String[] evalNames;
|
String[] evalNames;
|
||||||
@ -112,7 +112,7 @@ public class Trainer {
|
|||||||
* @param eval customized evaluation (set to null if not used)
|
* @param eval customized evaluation (set to null if not used)
|
||||||
* @return evaluation history
|
* @return evaluation history
|
||||||
*/
|
*/
|
||||||
public static String[] crossValiation(Iterable<Entry<String, Object>> params, DMatrix data, int round, int nfold, String[] metrics, IObjective obj, IEvaluation eval) {
|
public static String[] crossValiation(Iterable<Entry<String, Object>> params, DMatrix data, int round, int nfold, String[] metrics, IObjective obj, IEvaluation eval) throws XgboostError {
|
||||||
CVPack[] cvPacks = makeNFold(data, nfold, params, metrics);
|
CVPack[] cvPacks = makeNFold(data, nfold, params, metrics);
|
||||||
String[] evalHist = new String[round];
|
String[] evalHist = new String[round];
|
||||||
String[] results = new String[cvPacks.length];
|
String[] results = new String[cvPacks.length];
|
||||||
@ -149,7 +149,7 @@ public class Trainer {
|
|||||||
* @param evalMetrics Evaluation metrics
|
* @param evalMetrics Evaluation metrics
|
||||||
* @return CV package array
|
* @return CV package array
|
||||||
*/
|
*/
|
||||||
public static CVPack[] makeNFold(DMatrix data, int nfold, Iterable<Entry<String, Object>> params, String[] evalMetrics) {
|
public static CVPack[] makeNFold(DMatrix data, int nfold, Iterable<Entry<String, Object>> params, String[] evalMetrics) throws XgboostError {
|
||||||
List<Integer> samples = genRandPermutationNums(0, (int) data.rowNum());
|
List<Integer> samples = genRandPermutationNums(0, (int) data.rowNum());
|
||||||
int step = samples.size()/nfold;
|
int step = samples.size()/nfold;
|
||||||
int[] testSlice = new int[step];
|
int[] testSlice = new int[step];
|
||||||
|
|||||||
@ -0,0 +1,26 @@
|
|||||||
|
/*
|
||||||
|
Copyright (c) 2014 by Contributors
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
*/
|
||||||
|
package org.dmlc.xgboost4j.util;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* custom error class for xgboost
|
||||||
|
* @author hzx
|
||||||
|
*/
|
||||||
|
public class XgboostError extends Exception{
|
||||||
|
public XgboostError(String message) {
|
||||||
|
super(message);
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -17,32 +17,34 @@ package org.dmlc.xgboost4j.wrapper;
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* xgboost jni wrapper functions for xgboost_wrapper.h
|
* xgboost jni wrapper functions for xgboost_wrapper.h
|
||||||
|
* change 2015-7-6: *use a long[] (length=1) as container of handle to get the output DMatrix or Booster
|
||||||
* @author hzx
|
* @author hzx
|
||||||
*/
|
*/
|
||||||
public class XgboostJNI {
|
public class XgboostJNI {
|
||||||
public final static native long XGDMatrixCreateFromFile(String fname, int silent);
|
public final static native String XGBGetLastError();
|
||||||
public final static native long XGDMatrixCreateFromCSR(long[] indptr, int[] indices, float[] data);
|
public final static native int XGDMatrixCreateFromFile(String fname, int silent, long[] out);
|
||||||
public final static native long XGDMatrixCreateFromCSC(long[] colptr, int[] indices, float[] data);
|
public final static native int XGDMatrixCreateFromCSR(long[] indptr, int[] indices, float[] data, long[] out);
|
||||||
public final static native long XGDMatrixCreateFromMat(float[] data, int nrow, int ncol, float missing);
|
public final static native int XGDMatrixCreateFromCSC(long[] colptr, int[] indices, float[] data, long[] out);
|
||||||
public final static native long XGDMatrixSliceDMatrix(long handle, int[] idxset);
|
public final static native int XGDMatrixCreateFromMat(float[] data, int nrow, int ncol, float missing, long[] out);
|
||||||
public final static native void XGDMatrixFree(long handle);
|
public final static native int XGDMatrixSliceDMatrix(long handle, int[] idxset, long[] out);
|
||||||
public final static native void XGDMatrixSaveBinary(long handle, String fname, int silent);
|
public final static native int XGDMatrixFree(long handle);
|
||||||
public final static native void XGDMatrixSetFloatInfo(long handle, String field, float[] array);
|
public final static native int XGDMatrixSaveBinary(long handle, String fname, int silent);
|
||||||
public final static native void XGDMatrixSetUIntInfo(long handle, String field, int[] array);
|
public final static native int XGDMatrixSetFloatInfo(long handle, String field, float[] array);
|
||||||
public final static native void XGDMatrixSetGroup(long handle, int[] group);
|
public final static native int XGDMatrixSetUIntInfo(long handle, String field, int[] array);
|
||||||
public final static native float[] XGDMatrixGetFloatInfo(long handle, String field);
|
public final static native int XGDMatrixSetGroup(long handle, int[] group);
|
||||||
public final static native int[] XGDMatrixGetUIntInfo(long handle, String filed);
|
public final static native int XGDMatrixGetFloatInfo(long handle, String field, float[][] info);
|
||||||
public final static native long XGDMatrixNumRow(long handle);
|
public final static native int XGDMatrixGetUIntInfo(long handle, String filed, int[][] info);
|
||||||
public final static native long XGBoosterCreate(long[] handles);
|
public final static native int XGDMatrixNumRow(long handle, long[] row);
|
||||||
public final static native void XGBoosterFree(long handle);
|
public final static native int XGBoosterCreate(long[] handles, long[] out);
|
||||||
public final static native void XGBoosterSetParam(long handle, String name, String value);
|
public final static native int XGBoosterFree(long handle);
|
||||||
public final static native void XGBoosterUpdateOneIter(long handle, int iter, long dtrain);
|
public final static native int XGBoosterSetParam(long handle, String name, String value);
|
||||||
public final static native void XGBoosterBoostOneIter(long handle, long dtrain, float[] grad, float[] hess);
|
public final static native int XGBoosterUpdateOneIter(long handle, int iter, long dtrain);
|
||||||
public final static native String XGBoosterEvalOneIter(long handle, int iter, long[] dmats, String[] evnames);
|
public final static native int XGBoosterBoostOneIter(long handle, long dtrain, float[] grad, float[] hess);
|
||||||
public final static native float[] XGBoosterPredict(long handle, long dmat, int option_mask, long ntree_limit);
|
public final static native int XGBoosterEvalOneIter(long handle, int iter, long[] dmats, String[] evnames, String[] eval_info);
|
||||||
public final static native void XGBoosterLoadModel(long handle, String fname);
|
public final static native int XGBoosterPredict(long handle, long dmat, int option_mask, long ntree_limit, float[][] predicts);
|
||||||
public final static native void XGBoosterSaveModel(long handle, String fname);
|
public final static native int XGBoosterLoadModel(long handle, String fname);
|
||||||
public final static native void XGBoosterLoadModelFromBuffer(long handle, long buf, long len);
|
public final static native int XGBoosterSaveModel(long handle, String fname);
|
||||||
public final static native String XGBoosterGetModelRaw(long handle);
|
public final static native int XGBoosterLoadModelFromBuffer(long handle, long buf, long len);
|
||||||
public final static native String[] XGBoosterDumpModel(long handle, String fmap, int with_stats);
|
public final static native int XGBoosterGetModelRaw(long handle, String[] out_string);
|
||||||
|
public final static native int XGBoosterDumpModel(long handle, String fmap, int with_stats, String[][] out_strings);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -16,21 +16,34 @@
|
|||||||
#include "../wrapper/xgboost_wrapper.h"
|
#include "../wrapper/xgboost_wrapper.h"
|
||||||
#include "xgboost4j_wrapper.h"
|
#include "xgboost4j_wrapper.h"
|
||||||
|
|
||||||
JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromFile
|
JNIEXPORT jstring JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBGetLastError
|
||||||
(JNIEnv *jenv, jclass jcls, jstring jfname, jint jsilent) {
|
(JNIEnv *jenv, jclass jcls) {
|
||||||
jlong jresult = 0 ;
|
jstring jresult = 0 ;
|
||||||
|
char* result = 0;
|
||||||
|
result = (char *)XGBGetLastError();
|
||||||
|
if (result) jresult = jenv->NewStringUTF((const char *)result);
|
||||||
|
return jresult;
|
||||||
|
}
|
||||||
|
|
||||||
|
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromFile
|
||||||
|
(JNIEnv *jenv, jclass jcls, jstring jfname, jint jsilent, jlongArray jout) {
|
||||||
|
jint jresult = 0 ;
|
||||||
char *fname = (char *) 0 ;
|
char *fname = (char *) 0 ;
|
||||||
int silent;
|
int silent;
|
||||||
void *result = 0 ;
|
void* result[1];
|
||||||
fname = 0;
|
unsigned long out[1];
|
||||||
if (jfname) {
|
|
||||||
fname = (char *)jenv->GetStringUTFChars(jfname, 0);
|
fname = (char *)jenv->GetStringUTFChars(jfname, 0);
|
||||||
if (!fname) return 0;
|
|
||||||
}
|
|
||||||
silent = (int)jsilent;
|
silent = (int)jsilent;
|
||||||
result = (void *)XGDMatrixCreateFromFile((char const *)fname, silent);
|
jresult = (jint) XGDMatrixCreateFromFile((char const *)fname, silent, result);
|
||||||
*(void **)&jresult = result;
|
|
||||||
|
|
||||||
|
*(void **)&out[0] = *result;
|
||||||
|
|
||||||
if (fname) jenv->ReleaseStringUTFChars(jfname, (const char *)fname);
|
if (fname) jenv->ReleaseStringUTFChars(jfname, (const char *)fname);
|
||||||
|
|
||||||
|
jenv->SetLongArrayRegion(jout, 0, 1, (const jlong *) out);
|
||||||
return jresult;
|
return jresult;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -39,12 +52,13 @@ JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCrea
|
|||||||
* Method: XGDMatrixCreateFromCSR
|
* Method: XGDMatrixCreateFromCSR
|
||||||
* Signature: ([J[J[F)J
|
* Signature: ([J[J[F)J
|
||||||
*/
|
*/
|
||||||
JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromCSR
|
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromCSR
|
||||||
(JNIEnv *jenv, jclass jcls, jlongArray jindptr, jintArray jindices, jfloatArray jdata) {
|
(JNIEnv *jenv, jclass jcls, jlongArray jindptr, jintArray jindices, jfloatArray jdata, jlongArray jout) {
|
||||||
jlong jresult = 0 ;
|
jint jresult = 0 ;
|
||||||
bst_ulong nindptr ;
|
bst_ulong nindptr ;
|
||||||
bst_ulong nelem;
|
bst_ulong nelem;
|
||||||
void *result = 0 ;
|
void *result[1];
|
||||||
|
unsigned long out[1];
|
||||||
|
|
||||||
jlong* indptr = jenv->GetLongArrayElements(jindptr, 0);
|
jlong* indptr = jenv->GetLongArrayElements(jindptr, 0);
|
||||||
jint* indices = jenv->GetIntArrayElements(jindices, 0);
|
jint* indices = jenv->GetIntArrayElements(jindices, 0);
|
||||||
@ -52,8 +66,9 @@ JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCrea
|
|||||||
nindptr = (bst_ulong)jenv->GetArrayLength(jindptr);
|
nindptr = (bst_ulong)jenv->GetArrayLength(jindptr);
|
||||||
nelem = (bst_ulong)jenv->GetArrayLength(jdata);
|
nelem = (bst_ulong)jenv->GetArrayLength(jdata);
|
||||||
|
|
||||||
result = (void *)XGDMatrixCreateFromCSR((unsigned long const *)indptr, (unsigned int const *)indices, (float const *)data, nindptr, nelem);
|
jresult = (jint) XGDMatrixCreateFromCSR((unsigned long const *)indptr, (unsigned int const *)indices, (float const *)data, nindptr, nelem, result);
|
||||||
*(void **)&jresult = result;
|
*(void **)&out[0] = *result;
|
||||||
|
jenv->SetLongArrayRegion(jout, 0, 1, (const jlong *) out);
|
||||||
|
|
||||||
//release
|
//release
|
||||||
jenv->ReleaseLongArrayElements(jindptr, indptr, 0);
|
jenv->ReleaseLongArrayElements(jindptr, indptr, 0);
|
||||||
@ -68,12 +83,13 @@ JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCrea
|
|||||||
* Method: XGDMatrixCreateFromCSC
|
* Method: XGDMatrixCreateFromCSC
|
||||||
* Signature: ([J[J[F)J
|
* Signature: ([J[J[F)J
|
||||||
*/
|
*/
|
||||||
JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromCSC
|
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromCSC
|
||||||
(JNIEnv *jenv, jclass jcls, jlongArray jindptr, jintArray jindices, jfloatArray jdata) {
|
(JNIEnv *jenv, jclass jcls, jlongArray jindptr, jintArray jindices, jfloatArray jdata, jlongArray jout) {
|
||||||
jlong jresult = 0 ;
|
jint jresult = 0;
|
||||||
bst_ulong nindptr ;
|
bst_ulong nindptr ;
|
||||||
bst_ulong nelem;
|
bst_ulong nelem;
|
||||||
void *result = 0 ;
|
void *result[1];
|
||||||
|
unsigned long out[1];
|
||||||
|
|
||||||
jlong* indptr = jenv->GetLongArrayElements(jindptr, NULL);
|
jlong* indptr = jenv->GetLongArrayElements(jindptr, NULL);
|
||||||
jint* indices = jenv->GetIntArrayElements(jindices, 0);
|
jint* indices = jenv->GetIntArrayElements(jindices, 0);
|
||||||
@ -81,8 +97,9 @@ JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCrea
|
|||||||
nindptr = (bst_ulong)jenv->GetArrayLength(jindptr);
|
nindptr = (bst_ulong)jenv->GetArrayLength(jindptr);
|
||||||
nelem = (bst_ulong)jenv->GetArrayLength(jdata);
|
nelem = (bst_ulong)jenv->GetArrayLength(jdata);
|
||||||
|
|
||||||
result = (void *)XGDMatrixCreateFromCSC((unsigned long const *)indptr, (unsigned int const *)indices, (float const *)data, nindptr, nelem);
|
jresult = (jint) XGDMatrixCreateFromCSC((unsigned long const *)indptr, (unsigned int const *)indices, (float const *)data, nindptr, nelem, result);
|
||||||
*(void **)&jresult = result;
|
*(void **)&out[0] = *result;
|
||||||
|
jenv->SetLongArrayRegion(jout, 0, 1, (const jlong *) out);
|
||||||
|
|
||||||
//release
|
//release
|
||||||
jenv->ReleaseLongArrayElements(jindptr, indptr, 0);
|
jenv->ReleaseLongArrayElements(jindptr, indptr, 0);
|
||||||
@ -97,21 +114,24 @@ JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCrea
|
|||||||
* Method: XGDMatrixCreateFromMat
|
* Method: XGDMatrixCreateFromMat
|
||||||
* Signature: ([FIIF)J
|
* Signature: ([FIIF)J
|
||||||
*/
|
*/
|
||||||
JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromMat
|
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromMat
|
||||||
(JNIEnv *jenv, jclass jcls, jfloatArray jdata, jint jnrow, jint jncol, jfloat jmiss) {
|
(JNIEnv *jenv, jclass jcls, jfloatArray jdata, jint jnrow, jint jncol, jfloat jmiss, jlongArray jout) {
|
||||||
jlong jresult = 0 ;
|
jint jresult = 0 ;
|
||||||
bst_ulong nrow ;
|
bst_ulong nrow ;
|
||||||
bst_ulong ncol ;
|
bst_ulong ncol ;
|
||||||
float miss ;
|
float miss ;
|
||||||
void *result = 0 ;
|
void *result[1];
|
||||||
|
unsigned long out[1];
|
||||||
|
|
||||||
|
|
||||||
jfloat* data = jenv->GetFloatArrayElements(jdata, 0);
|
jfloat* data = jenv->GetFloatArrayElements(jdata, 0);
|
||||||
nrow = (bst_ulong)jnrow;
|
nrow = (bst_ulong)jnrow;
|
||||||
ncol = (bst_ulong)jncol;
|
ncol = (bst_ulong)jncol;
|
||||||
miss = (float)jmiss;
|
miss = (float)jmiss;
|
||||||
result = (void *)XGDMatrixCreateFromMat((float const *)data, nrow, ncol, miss);
|
|
||||||
*(void **)&jresult = result;
|
jresult = (jint) XGDMatrixCreateFromMat((float const *)data, nrow, ncol, miss, result);
|
||||||
|
*(void **)&out[0] = *result;
|
||||||
|
jenv->SetLongArrayRegion(jout, 0, 1, (const jlong *) out);
|
||||||
|
|
||||||
//release
|
//release
|
||||||
jenv->ReleaseFloatArrayElements(jdata, data, 0);
|
jenv->ReleaseFloatArrayElements(jdata, data, 0);
|
||||||
@ -124,19 +144,21 @@ JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCrea
|
|||||||
* Method: XGDMatrixSliceDMatrix
|
* Method: XGDMatrixSliceDMatrix
|
||||||
* Signature: (J[I)J
|
* Signature: (J[I)J
|
||||||
*/
|
*/
|
||||||
JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSliceDMatrix
|
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSliceDMatrix
|
||||||
(JNIEnv *jenv, jclass jcls, jlong jhandle, jintArray jindexset) {
|
(JNIEnv *jenv, jclass jcls, jlong jhandle, jintArray jindexset, jlongArray jout) {
|
||||||
jlong jresult = 0 ;
|
jint jresult = 0 ;
|
||||||
void *handle = (void *) 0 ;
|
void *handle = (void *) 0 ;
|
||||||
bst_ulong len;
|
bst_ulong len;
|
||||||
void *result = 0 ;
|
void *result[1];
|
||||||
|
unsigned long out[1];
|
||||||
|
|
||||||
jint* indexset = jenv->GetIntArrayElements(jindexset, 0);
|
jint* indexset = jenv->GetIntArrayElements(jindexset, 0);
|
||||||
handle = *(void **)&jhandle;
|
handle = *(void **)&jhandle;
|
||||||
len = (bst_ulong)jenv->GetArrayLength(jindexset);
|
len = (bst_ulong)jenv->GetArrayLength(jindexset);
|
||||||
|
|
||||||
result = (void *)XGDMatrixSliceDMatrix(handle, (int const *)indexset, len);
|
jresult = (jint) XGDMatrixSliceDMatrix(handle, (int const *)indexset, len, result);
|
||||||
*(void **)&jresult = result;
|
*(void **)&out[0] = *result;
|
||||||
|
jenv->SetLongArrayRegion(jout, 0, 1, (const jlong *) out);
|
||||||
|
|
||||||
//release
|
//release
|
||||||
jenv->ReleaseIntArrayElements(jindexset, indexset, 0);
|
jenv->ReleaseIntArrayElements(jindexset, indexset, 0);
|
||||||
@ -149,11 +171,13 @@ JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSlic
|
|||||||
* Method: XGDMatrixFree
|
* Method: XGDMatrixFree
|
||||||
* Signature: (J)V
|
* Signature: (J)V
|
||||||
*/
|
*/
|
||||||
JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixFree
|
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixFree
|
||||||
(JNIEnv *jenv, jclass jcls, jlong jhandle) {
|
(JNIEnv *jenv, jclass jcls, jlong jhandle) {
|
||||||
|
jint jresult = 0;
|
||||||
void *handle = (void *) 0 ;
|
void *handle = (void *) 0 ;
|
||||||
handle = *(void **)&jhandle;
|
handle = *(void **)&jhandle;
|
||||||
XGDMatrixFree(handle);
|
jresult = (jint) XGDMatrixFree(handle);
|
||||||
|
return jresult;
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
@ -161,20 +185,21 @@ JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixFree
|
|||||||
* Method: XGDMatrixSaveBinary
|
* Method: XGDMatrixSaveBinary
|
||||||
* Signature: (JLjava/lang/String;I)V
|
* Signature: (JLjava/lang/String;I)V
|
||||||
*/
|
*/
|
||||||
JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSaveBinary
|
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSaveBinary
|
||||||
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfname, jint jsilent) {
|
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfname, jint jsilent) {
|
||||||
|
jint jresult = 0;
|
||||||
void *handle = (void *) 0 ;
|
void *handle = (void *) 0 ;
|
||||||
char *fname = (char *) 0 ;
|
char *fname = (char *) 0 ;
|
||||||
int silent ;
|
int silent ;
|
||||||
handle = *(void **)&jhandle;
|
handle = *(void **)&jhandle;
|
||||||
fname = 0;
|
fname = 0;
|
||||||
if (jfname) {
|
|
||||||
fname = (char *)jenv->GetStringUTFChars(jfname, 0);
|
fname = (char *)jenv->GetStringUTFChars(jfname, 0);
|
||||||
if (!fname) return ;
|
|
||||||
}
|
|
||||||
silent = (int)jsilent;
|
silent = (int)jsilent;
|
||||||
XGDMatrixSaveBinary(handle, (char const *)fname, silent);
|
jresult = (jint) XGDMatrixSaveBinary(handle, (char const *)fname, silent);
|
||||||
if (fname) jenv->ReleaseStringUTFChars(jfname, (const char *)fname);
|
if (fname) jenv->ReleaseStringUTFChars(jfname, (const char *)fname);
|
||||||
|
|
||||||
|
return jresult;
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
@ -182,27 +207,28 @@ JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSaveB
|
|||||||
* Method: XGDMatrixSetFloatInfo
|
* Method: XGDMatrixSetFloatInfo
|
||||||
* Signature: (JLjava/lang/String;[F)V
|
* Signature: (JLjava/lang/String;[F)V
|
||||||
*/
|
*/
|
||||||
JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetFloatInfo
|
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetFloatInfo
|
||||||
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jfloatArray jarray) {
|
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jfloatArray jarray) {
|
||||||
|
jint jresult = 0;
|
||||||
void *handle = (void *) 0 ;
|
void *handle = (void *) 0 ;
|
||||||
char *field = (char *) 0 ;
|
char *field = (char *) 0 ;
|
||||||
bst_ulong len;
|
bst_ulong len;
|
||||||
|
|
||||||
|
|
||||||
handle = *(void **)&jhandle;
|
handle = *(void **)&jhandle;
|
||||||
field = 0;
|
|
||||||
if (jfield) {
|
|
||||||
field = (char *)jenv->GetStringUTFChars(jfield, 0);
|
field = (char *)jenv->GetStringUTFChars(jfield, 0);
|
||||||
if (!field) return ;
|
|
||||||
}
|
|
||||||
|
|
||||||
jfloat* array = jenv->GetFloatArrayElements(jarray, NULL);
|
jfloat* array = jenv->GetFloatArrayElements(jarray, NULL);
|
||||||
len = (bst_ulong)jenv->GetArrayLength(jarray);
|
len = (bst_ulong)jenv->GetArrayLength(jarray);
|
||||||
XGDMatrixSetFloatInfo(handle, (char const *)field, (float const *)array, len);
|
jresult = (jint) XGDMatrixSetFloatInfo(handle, (char const *)field, (float const *)array, len);
|
||||||
|
|
||||||
//release
|
//release
|
||||||
if (field) jenv->ReleaseStringUTFChars(jfield, (const char *)field);
|
if (field) jenv->ReleaseStringUTFChars(jfield, (const char *)field);
|
||||||
jenv->ReleaseFloatArrayElements(jarray, array, 0);
|
jenv->ReleaseFloatArrayElements(jarray, array, 0);
|
||||||
|
|
||||||
|
return jresult;
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
@ -210,25 +236,26 @@ JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetFl
|
|||||||
* Method: XGDMatrixSetUIntInfo
|
* Method: XGDMatrixSetUIntInfo
|
||||||
* Signature: (JLjava/lang/String;[I)V
|
* Signature: (JLjava/lang/String;[I)V
|
||||||
*/
|
*/
|
||||||
JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetUIntInfo
|
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetUIntInfo
|
||||||
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jintArray jarray) {
|
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jintArray jarray) {
|
||||||
|
jint jresult = 0;
|
||||||
void *handle = (void *) 0 ;
|
void *handle = (void *) 0 ;
|
||||||
char *field = (char *) 0 ;
|
char *field = (char *) 0 ;
|
||||||
bst_ulong len ;
|
bst_ulong len ;
|
||||||
handle = *(void **)&jhandle;
|
handle = *(void **)&jhandle;
|
||||||
field = 0;
|
field = 0;
|
||||||
if (jfield) {
|
|
||||||
field = (char *)jenv->GetStringUTFChars(jfield, 0);
|
field = (char *)jenv->GetStringUTFChars(jfield, 0);
|
||||||
if (!field) return ;
|
|
||||||
}
|
|
||||||
|
|
||||||
jint* array = jenv->GetIntArrayElements(jarray, NULL);
|
jint* array = jenv->GetIntArrayElements(jarray, NULL);
|
||||||
len = (bst_ulong)jenv->GetArrayLength(jarray);
|
len = (bst_ulong)jenv->GetArrayLength(jarray);
|
||||||
|
|
||||||
XGDMatrixSetUIntInfo(handle, (char const *)field, (unsigned int const *)array, len);
|
jresult = (jint) XGDMatrixSetUIntInfo(handle, (char const *)field, (unsigned int const *)array, len);
|
||||||
//release
|
//release
|
||||||
if (field) jenv->ReleaseStringUTFChars(jfield, (const char *)field);
|
if (field) jenv->ReleaseStringUTFChars(jfield, (const char *)field);
|
||||||
jenv->ReleaseIntArrayElements(jarray, array, 0);
|
jenv->ReleaseIntArrayElements(jarray, array, 0);
|
||||||
|
|
||||||
|
return jresult;
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
@ -236,8 +263,9 @@ JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetUI
|
|||||||
* Method: XGDMatrixSetGroup
|
* Method: XGDMatrixSetGroup
|
||||||
* Signature: (J[I)V
|
* Signature: (J[I)V
|
||||||
*/
|
*/
|
||||||
JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetGroup
|
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetGroup
|
||||||
(JNIEnv * jenv, jclass jcls, jlong jhandle, jintArray jarray) {
|
(JNIEnv * jenv, jclass jcls, jlong jhandle, jintArray jarray) {
|
||||||
|
jint jresult = 0;
|
||||||
void *handle = (void *) 0 ;
|
void *handle = (void *) 0 ;
|
||||||
bst_ulong len ;
|
bst_ulong len ;
|
||||||
|
|
||||||
@ -245,11 +273,12 @@ JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetGr
|
|||||||
jint* array = jenv->GetIntArrayElements(jarray, NULL);
|
jint* array = jenv->GetIntArrayElements(jarray, NULL);
|
||||||
len = (bst_ulong)jenv->GetArrayLength(jarray);
|
len = (bst_ulong)jenv->GetArrayLength(jarray);
|
||||||
|
|
||||||
XGDMatrixSetGroup(handle, (unsigned int const *)array, len);
|
jresult = (jint) XGDMatrixSetGroup(handle, (unsigned int const *)array, len);
|
||||||
|
|
||||||
//release
|
//release
|
||||||
jenv->ReleaseIntArrayElements(jarray, array, 0);
|
jenv->ReleaseIntArrayElements(jarray, array, 0);
|
||||||
|
|
||||||
|
return jresult;
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
@ -257,13 +286,14 @@ JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetGr
|
|||||||
* Method: XGDMatrixGetFloatInfo
|
* Method: XGDMatrixGetFloatInfo
|
||||||
* Signature: (JLjava/lang/String;)[F
|
* Signature: (JLjava/lang/String;)[F
|
||||||
*/
|
*/
|
||||||
JNIEXPORT jfloatArray JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixGetFloatInfo
|
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixGetFloatInfo
|
||||||
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield) {
|
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jobjectArray jout) {
|
||||||
|
jint jresult = 0;
|
||||||
void *handle = (void *) 0 ;
|
void *handle = (void *) 0 ;
|
||||||
char *field = (char *) 0 ;
|
char *field = (char *) 0 ;
|
||||||
bst_ulong len[1];
|
bst_ulong len[1];
|
||||||
*len = 0;
|
*len = 0;
|
||||||
float *result = 0 ;
|
float *result[1];
|
||||||
|
|
||||||
handle = *(void **)&jhandle;
|
handle = *(void **)&jhandle;
|
||||||
field = 0;
|
field = 0;
|
||||||
@ -272,13 +302,15 @@ JNIEXPORT jfloatArray JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatr
|
|||||||
if (!field) return 0;
|
if (!field) return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
result = (float *)XGDMatrixGetFloatInfo((void const *)handle, (char const *)field, len);
|
jresult = (jint) XGDMatrixGetFloatInfo(handle, (char const *)field, len, (const float **) result);
|
||||||
|
|
||||||
if (field) jenv->ReleaseStringUTFChars(jfield, (const char *)field);
|
if (field) jenv->ReleaseStringUTFChars(jfield, (const char *)field);
|
||||||
|
|
||||||
jsize jlen = (jsize)*len;
|
jsize jlen = (jsize)*len;
|
||||||
jfloatArray jresult = jenv->NewFloatArray(jlen);
|
jfloatArray jarray = jenv->NewFloatArray(jlen);
|
||||||
jenv->SetFloatArrayRegion(jresult, 0, jlen, (jfloat *)result);
|
jenv->SetFloatArrayRegion(jarray, 0, jlen, (jfloat *) *result);
|
||||||
|
jenv->SetObjectArrayElement(jout, 0, (jobject) jarray);
|
||||||
|
|
||||||
return jresult;
|
return jresult;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -287,28 +319,26 @@ JNIEXPORT jfloatArray JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatr
|
|||||||
* Method: XGDMatrixGetUIntInfo
|
* Method: XGDMatrixGetUIntInfo
|
||||||
* Signature: (JLjava/lang/String;)[I
|
* Signature: (JLjava/lang/String;)[I
|
||||||
*/
|
*/
|
||||||
JNIEXPORT jintArray JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixGetUIntInfo
|
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixGetUIntInfo
|
||||||
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield) {
|
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jobjectArray jout) {
|
||||||
|
jint jresult = 0;
|
||||||
void *handle = (void *) 0 ;
|
void *handle = (void *) 0 ;
|
||||||
char *field = (char *) 0 ;
|
char *field = (char *) 0 ;
|
||||||
bst_ulong len[1];
|
bst_ulong len[1];
|
||||||
*len = 0;
|
*len = 0;
|
||||||
unsigned int *result = 0 ;
|
unsigned int *result[1];
|
||||||
|
|
||||||
handle = *(void **)&jhandle;
|
handle = *(void **)&jhandle;
|
||||||
field = 0;
|
|
||||||
if (jfield) {
|
|
||||||
field = (char *)jenv->GetStringUTFChars(jfield, 0);
|
field = (char *)jenv->GetStringUTFChars(jfield, 0);
|
||||||
if (!field) return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
result = (unsigned int *)XGDMatrixGetUIntInfo((void const *)handle, (char const *)field, len);
|
jresult = (jint) XGDMatrixGetUIntInfo(handle, (char const *)field, len, (const unsigned int **) result);
|
||||||
|
|
||||||
if (field) jenv->ReleaseStringUTFChars(jfield, (const char *)field);
|
if (field) jenv->ReleaseStringUTFChars(jfield, (const char *)field);
|
||||||
|
|
||||||
jsize jlen = (jsize)*len;
|
jsize jlen = (jsize)*len;
|
||||||
jintArray jresult = jenv->NewIntArray(jlen);
|
jintArray jarray = jenv->NewIntArray(jlen);
|
||||||
jenv->SetIntArrayRegion(jresult, 0, jlen, (jint *)result);
|
jenv->SetIntArrayRegion(jarray, 0, jlen, (jint *) *result);
|
||||||
|
jenv->SetObjectArrayElement(jout, 0, jarray);
|
||||||
return jresult;
|
return jresult;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -317,14 +347,14 @@ JNIEXPORT jintArray JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrix
|
|||||||
* Method: XGDMatrixNumRow
|
* Method: XGDMatrixNumRow
|
||||||
* Signature: (J)J
|
* Signature: (J)J
|
||||||
*/
|
*/
|
||||||
JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixNumRow
|
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixNumRow
|
||||||
(JNIEnv *jenv, jclass jcls, jlong jhandle) {
|
(JNIEnv *jenv, jclass jcls, jlong jhandle, jlongArray jout) {
|
||||||
jlong jresult = 0 ;
|
jint jresult = 0 ;
|
||||||
void *handle = (void *) 0 ;
|
void *handle = (void *) 0 ;
|
||||||
bst_ulong result;
|
bst_ulong result[1];
|
||||||
handle = *(void **)&jhandle;
|
handle = *(void **)&jhandle;
|
||||||
result = (bst_ulong)XGDMatrixNumRow((void const *)handle);
|
jresult = (jint) XGDMatrixNumRow(handle, result);
|
||||||
jresult = (jlong)result;
|
jenv->SetLongArrayRegion(jout, 0, 1, (const jlong *) result);
|
||||||
return jresult;
|
return jresult;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -333,13 +363,14 @@ JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixNumR
|
|||||||
* Method: XGBoosterCreate
|
* Method: XGBoosterCreate
|
||||||
* Signature: ([J)J
|
* Signature: ([J)J
|
||||||
*/
|
*/
|
||||||
JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterCreate
|
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterCreate
|
||||||
(JNIEnv *jenv, jclass jcls, jlongArray jhandles) {
|
(JNIEnv *jenv, jclass jcls, jlongArray jhandles, jlongArray jout) {
|
||||||
jlong jresult = 0 ;
|
jint jresult = 0;
|
||||||
void **handles = 0;
|
void **handles = 0;
|
||||||
bst_ulong len = 0;
|
bst_ulong len = 0;
|
||||||
void *result = 0 ;
|
void *result[1];
|
||||||
jlong* cjhandles = 0;
|
jlong* cjhandles = 0;
|
||||||
|
unsigned long out[1];
|
||||||
|
|
||||||
if(jhandles) {
|
if(jhandles) {
|
||||||
len = (bst_ulong)jenv->GetArrayLength(jhandles);
|
len = (bst_ulong)jenv->GetArrayLength(jhandles);
|
||||||
@ -351,7 +382,7 @@ JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterCrea
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
result = (void *)XGBoosterCreate(handles, len);
|
jresult = (jint) XGBoosterCreate(handles, len, result);
|
||||||
|
|
||||||
//release
|
//release
|
||||||
if(jhandles) {
|
if(jhandles) {
|
||||||
@ -359,7 +390,9 @@ JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterCrea
|
|||||||
jenv->ReleaseLongArrayElements(jhandles, cjhandles, 0);
|
jenv->ReleaseLongArrayElements(jhandles, cjhandles, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
*(void **)&jresult = result;
|
*(void **)&out[0] = *result;
|
||||||
|
jenv->SetLongArrayRegion(jout, 0, 1, (const jlong *) out);
|
||||||
|
|
||||||
return jresult;
|
return jresult;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -368,11 +401,11 @@ JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterCrea
|
|||||||
* Method: XGBoosterFree
|
* Method: XGBoosterFree
|
||||||
* Signature: (J)V
|
* Signature: (J)V
|
||||||
*/
|
*/
|
||||||
JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterFree
|
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterFree
|
||||||
(JNIEnv *jenv, jclass jcls, jlong jhandle) {
|
(JNIEnv *jenv, jclass jcls, jlong jhandle) {
|
||||||
void *handle = (void *) 0 ;
|
void *handle = (void *) 0 ;
|
||||||
handle = *(void **)&jhandle;
|
handle = *(void **)&jhandle;
|
||||||
XGBoosterFree(handle);
|
return (jint) XGBoosterFree(handle);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -381,27 +414,22 @@ JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterFree
|
|||||||
* Method: XGBoosterSetParam
|
* Method: XGBoosterSetParam
|
||||||
* Signature: (JLjava/lang/String;Ljava/lang/String;)V
|
* Signature: (JLjava/lang/String;Ljava/lang/String;)V
|
||||||
*/
|
*/
|
||||||
JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterSetParam
|
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterSetParam
|
||||||
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jname, jstring jvalue) {
|
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jname, jstring jvalue) {
|
||||||
|
jint jresult = -1;
|
||||||
void *handle = (void *) 0 ;
|
void *handle = (void *) 0 ;
|
||||||
char *name = (char *) 0 ;
|
char *name = (char *) 0 ;
|
||||||
char *value = (char *) 0 ;
|
char *value = (char *) 0 ;
|
||||||
handle = *(void **)&jhandle;
|
handle = *(void **)&jhandle;
|
||||||
|
|
||||||
name = 0;
|
|
||||||
if (jname) {
|
|
||||||
name = (char *)jenv->GetStringUTFChars(jname, 0);
|
name = (char *)jenv->GetStringUTFChars(jname, 0);
|
||||||
if (!name) return ;
|
|
||||||
}
|
|
||||||
|
|
||||||
value = 0;
|
|
||||||
if (jvalue) {
|
|
||||||
value = (char *)jenv->GetStringUTFChars(jvalue, 0);
|
value = (char *)jenv->GetStringUTFChars(jvalue, 0);
|
||||||
if (!value) return ;
|
|
||||||
}
|
jresult = (jint) XGBoosterSetParam(handle, (char const *)name, (char const *)value);
|
||||||
XGBoosterSetParam(handle, (char const *)name, (char const *)value);
|
|
||||||
if (name) jenv->ReleaseStringUTFChars(jname, (const char *)name);
|
if (name) jenv->ReleaseStringUTFChars(jname, (const char *)name);
|
||||||
if (value) jenv->ReleaseStringUTFChars(jvalue, (const char *)value);
|
if (value) jenv->ReleaseStringUTFChars(jvalue, (const char *)value);
|
||||||
|
|
||||||
|
return jresult;
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
@ -409,7 +437,7 @@ JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterSetPa
|
|||||||
* Method: XGBoosterUpdateOneIter
|
* Method: XGBoosterUpdateOneIter
|
||||||
* Signature: (JIJ)V
|
* Signature: (JIJ)V
|
||||||
*/
|
*/
|
||||||
JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterUpdateOneIter
|
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterUpdateOneIter
|
||||||
(JNIEnv *jenv, jclass jcls, jlong jhandle, jint jiter, jlong jdtrain) {
|
(JNIEnv *jenv, jclass jcls, jlong jhandle, jint jiter, jlong jdtrain) {
|
||||||
void *handle = (void *) 0 ;
|
void *handle = (void *) 0 ;
|
||||||
int iter ;
|
int iter ;
|
||||||
@ -417,7 +445,7 @@ JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterUpdat
|
|||||||
handle = *(void **)&jhandle;
|
handle = *(void **)&jhandle;
|
||||||
iter = (int)jiter;
|
iter = (int)jiter;
|
||||||
dtrain = *(void **)&jdtrain;
|
dtrain = *(void **)&jdtrain;
|
||||||
XGBoosterUpdateOneIter(handle, iter, dtrain);
|
return (jint) XGBoosterUpdateOneIter(handle, iter, dtrain);
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
@ -425,8 +453,9 @@ JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterUpdat
|
|||||||
* Method: XGBoosterBoostOneIter
|
* Method: XGBoosterBoostOneIter
|
||||||
* Signature: (JJ[F[F)V
|
* Signature: (JJ[F[F)V
|
||||||
*/
|
*/
|
||||||
JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterBoostOneIter
|
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterBoostOneIter
|
||||||
(JNIEnv *jenv, jclass jcls, jlong jhandle, jlong jdtrain, jfloatArray jgrad, jfloatArray jhess) {
|
(JNIEnv *jenv, jclass jcls, jlong jhandle, jlong jdtrain, jfloatArray jgrad, jfloatArray jhess) {
|
||||||
|
jint jresult = 0;
|
||||||
void *handle = (void *) 0 ;
|
void *handle = (void *) 0 ;
|
||||||
void *dtrain = (void *) 0 ;
|
void *dtrain = (void *) 0 ;
|
||||||
bst_ulong len ;
|
bst_ulong len ;
|
||||||
@ -436,11 +465,13 @@ JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterBoost
|
|||||||
jfloat* grad = jenv->GetFloatArrayElements(jgrad, 0);
|
jfloat* grad = jenv->GetFloatArrayElements(jgrad, 0);
|
||||||
jfloat* hess = jenv->GetFloatArrayElements(jhess, 0);
|
jfloat* hess = jenv->GetFloatArrayElements(jhess, 0);
|
||||||
len = (bst_ulong)jenv->GetArrayLength(jgrad);
|
len = (bst_ulong)jenv->GetArrayLength(jgrad);
|
||||||
XGBoosterBoostOneIter(handle, dtrain, grad, hess, len);
|
jresult = (jint) XGBoosterBoostOneIter(handle, dtrain, grad, hess, len);
|
||||||
|
|
||||||
//release
|
//release
|
||||||
jenv->ReleaseFloatArrayElements(jgrad, grad, 0);
|
jenv->ReleaseFloatArrayElements(jgrad, grad, 0);
|
||||||
jenv->ReleaseFloatArrayElements(jhess, hess, 0);
|
jenv->ReleaseFloatArrayElements(jhess, hess, 0);
|
||||||
|
|
||||||
|
return jresult;
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
@ -448,15 +479,15 @@ JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterBoost
|
|||||||
* Method: XGBoosterEvalOneIter
|
* Method: XGBoosterEvalOneIter
|
||||||
* Signature: (JI[J[Ljava/lang/String;)Ljava/lang/String;
|
* Signature: (JI[J[Ljava/lang/String;)Ljava/lang/String;
|
||||||
*/
|
*/
|
||||||
JNIEXPORT jstring JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterEvalOneIter
|
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterEvalOneIter
|
||||||
(JNIEnv *jenv, jclass jcls, jlong jhandle, jint jiter, jlongArray jdmats, jobjectArray jevnames) {
|
(JNIEnv *jenv, jclass jcls, jlong jhandle, jint jiter, jlongArray jdmats, jobjectArray jevnames, jobjectArray jout) {
|
||||||
jstring jresult = 0 ;
|
jint jresult = 0 ;
|
||||||
void *handle = (void *) 0 ;
|
void *handle = (void *) 0 ;
|
||||||
int iter ;
|
int iter ;
|
||||||
void **dmats = 0;
|
void **dmats = 0;
|
||||||
char **evnames = 0;
|
char **evnames = 0;
|
||||||
bst_ulong len ;
|
bst_ulong len ;
|
||||||
char *result = 0 ;
|
char *result[1];
|
||||||
|
|
||||||
handle = *(void **)&jhandle;
|
handle = *(void **)&jhandle;
|
||||||
iter = (int)jiter;
|
iter = (int)jiter;
|
||||||
@ -480,7 +511,7 @@ JNIEXPORT jstring JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterEv
|
|||||||
evnames[i] = (char *)jenv->GetStringUTFChars(jevname, 0);
|
evnames[i] = (char *)jenv->GetStringUTFChars(jevname, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
result = (char *)XGBoosterEvalOneIter(handle, iter, dmats, (char const *(*))evnames, len);
|
jresult = (jint) XGBoosterEvalOneIter(handle, iter, dmats, (char const *(*))evnames, len, (const char **) result);
|
||||||
|
|
||||||
if(len > 0) {
|
if(len > 0) {
|
||||||
delete[] dmats;
|
delete[] dmats;
|
||||||
@ -493,7 +524,9 @@ JNIEXPORT jstring JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterEv
|
|||||||
jenv->ReleaseLongArrayElements(jdmats, cjdmats, 0);
|
jenv->ReleaseLongArrayElements(jdmats, cjdmats, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (result) jresult = jenv->NewStringUTF((const char *)result);
|
jstring jinfo = 0;
|
||||||
|
if (*result) jinfo = jenv->NewStringUTF((const char *) *result);
|
||||||
|
jenv->SetObjectArrayElement(jout, 0, jinfo);
|
||||||
|
|
||||||
return jresult;
|
return jresult;
|
||||||
}
|
}
|
||||||
@ -503,26 +536,29 @@ JNIEXPORT jstring JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterEv
|
|||||||
* Method: XGBoosterPredict
|
* Method: XGBoosterPredict
|
||||||
* Signature: (JJIJ)[F
|
* Signature: (JJIJ)[F
|
||||||
*/
|
*/
|
||||||
JNIEXPORT jfloatArray JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterPredict
|
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterPredict
|
||||||
(JNIEnv *jenv, jclass jcls, jlong jhandle, jlong jdmat, jint joption_mask, jlong jntree_limit) {
|
(JNIEnv *jenv, jclass jcls, jlong jhandle, jlong jdmat, jint joption_mask, jlong jntree_limit, jobjectArray jout) {
|
||||||
|
jint jresult = 0;
|
||||||
void *handle = (void *) 0 ;
|
void *handle = (void *) 0 ;
|
||||||
void *dmat = (void *) 0 ;
|
void *dmat = (void *) 0 ;
|
||||||
int option_mask ;
|
int option_mask ;
|
||||||
unsigned int ntree_limit ;
|
unsigned int ntree_limit ;
|
||||||
bst_ulong len[1];
|
bst_ulong len[1];
|
||||||
*len = 0;
|
*len = 0;
|
||||||
float *result = 0 ;
|
float *result[1];
|
||||||
|
|
||||||
handle = *(void **)&jhandle;
|
handle = *(void **)&jhandle;
|
||||||
dmat = *(void **)&jdmat;
|
dmat = *(void **)&jdmat;
|
||||||
option_mask = (int)joption_mask;
|
option_mask = (int)joption_mask;
|
||||||
ntree_limit = (unsigned int)jntree_limit;
|
ntree_limit = (unsigned int)jntree_limit;
|
||||||
|
|
||||||
result = (float *)XGBoosterPredict(handle, dmat, option_mask, ntree_limit, len);
|
jresult = (jint) XGBoosterPredict(handle, dmat, option_mask, ntree_limit, len, (const float **) result);
|
||||||
|
|
||||||
jsize jlen = (jsize)*len;
|
jsize jlen = (jsize)*len;
|
||||||
jfloatArray jresult = jenv->NewFloatArray(jlen);
|
jfloatArray jarray = jenv->NewFloatArray(jlen);
|
||||||
jenv->SetFloatArrayRegion(jresult, 0, jlen, (jfloat *)result);
|
jenv->SetFloatArrayRegion(jarray, 0, jlen, (jfloat *) *result);
|
||||||
|
jenv->SetObjectArrayElement(jout, 0, jarray);
|
||||||
|
|
||||||
return jresult;
|
return jresult;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -531,18 +567,20 @@ JNIEXPORT jfloatArray JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoost
|
|||||||
* Method: XGBoosterLoadModel
|
* Method: XGBoosterLoadModel
|
||||||
* Signature: (JLjava/lang/String;)V
|
* Signature: (JLjava/lang/String;)V
|
||||||
*/
|
*/
|
||||||
JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterLoadModel
|
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterLoadModel
|
||||||
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfname) {
|
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfname) {
|
||||||
|
jint jresult = 0;
|
||||||
void *handle = (void *) 0 ;
|
void *handle = (void *) 0 ;
|
||||||
char *fname = (char *) 0 ;
|
char *fname = (char *) 0 ;
|
||||||
handle = *(void **)&jhandle;
|
handle = *(void **)&jhandle;
|
||||||
fname = 0;
|
|
||||||
if (jfname) {
|
|
||||||
fname = (char *)jenv->GetStringUTFChars(jfname, 0);
|
fname = (char *)jenv->GetStringUTFChars(jfname, 0);
|
||||||
if (!fname) return ;
|
|
||||||
}
|
|
||||||
XGBoosterLoadModel(handle,(char const *)fname);
|
jresult = (jint) XGBoosterLoadModel(handle,(char const *)fname);
|
||||||
if (fname) jenv->ReleaseStringUTFChars(jfname, (const char *)fname);
|
if (fname) jenv->ReleaseStringUTFChars(jfname, (const char *)fname);
|
||||||
|
|
||||||
|
return jresult;
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
@ -550,18 +588,19 @@ JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterLoadM
|
|||||||
* Method: XGBoosterSaveModel
|
* Method: XGBoosterSaveModel
|
||||||
* Signature: (JLjava/lang/String;)V
|
* Signature: (JLjava/lang/String;)V
|
||||||
*/
|
*/
|
||||||
JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterSaveModel
|
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterSaveModel
|
||||||
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfname) {
|
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfname) {
|
||||||
|
jint jresult = 0;
|
||||||
void *handle = (void *) 0 ;
|
void *handle = (void *) 0 ;
|
||||||
char *fname = (char *) 0 ;
|
char *fname = (char *) 0 ;
|
||||||
handle = *(void **)&jhandle;
|
handle = *(void **)&jhandle;
|
||||||
fname = 0;
|
fname = 0;
|
||||||
if (jfname) {
|
|
||||||
fname = (char *)jenv->GetStringUTFChars(jfname, 0);
|
fname = (char *)jenv->GetStringUTFChars(jfname, 0);
|
||||||
if (!fname) return ;
|
|
||||||
}
|
jresult = (jint) XGBoosterSaveModel(handle, (char const *)fname);
|
||||||
XGBoosterSaveModel(handle, (char const *)fname);
|
|
||||||
if (fname) jenv->ReleaseStringUTFChars(jfname, (const char *)fname);
|
if (fname) jenv->ReleaseStringUTFChars(jfname, (const char *)fname);
|
||||||
|
|
||||||
|
return jresult;
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
@ -569,7 +608,7 @@ JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterSaveM
|
|||||||
* Method: XGBoosterLoadModelFromBuffer
|
* Method: XGBoosterLoadModelFromBuffer
|
||||||
* Signature: (JJJ)V
|
* Signature: (JJJ)V
|
||||||
*/
|
*/
|
||||||
JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterLoadModelFromBuffer
|
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterLoadModelFromBuffer
|
||||||
(JNIEnv *jenv, jclass jcls, jlong jhandle, jlong jbuf, jlong jlen) {
|
(JNIEnv *jenv, jclass jcls, jlong jhandle, jlong jbuf, jlong jlen) {
|
||||||
void *handle = (void *) 0 ;
|
void *handle = (void *) 0 ;
|
||||||
void *buf = (void *) 0 ;
|
void *buf = (void *) 0 ;
|
||||||
@ -577,7 +616,7 @@ JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterLoadM
|
|||||||
handle = *(void **)&jhandle;
|
handle = *(void **)&jhandle;
|
||||||
buf = *(void **)&jbuf;
|
buf = *(void **)&jbuf;
|
||||||
len = (bst_ulong)jlen;
|
len = (bst_ulong)jlen;
|
||||||
XGBoosterLoadModelFromBuffer(handle, (void const *)buf, len);
|
return (jint) XGBoosterLoadModelFromBuffer(handle, (void const *)buf, len);
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
@ -585,17 +624,21 @@ JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterLoadM
|
|||||||
* Method: XGBoosterGetModelRaw
|
* Method: XGBoosterGetModelRaw
|
||||||
* Signature: (J)Ljava/lang/String;
|
* Signature: (J)Ljava/lang/String;
|
||||||
*/
|
*/
|
||||||
JNIEXPORT jstring JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterGetModelRaw
|
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterGetModelRaw
|
||||||
(JNIEnv * jenv, jclass jcls, jlong jhandle) {
|
(JNIEnv * jenv, jclass jcls, jlong jhandle, jobjectArray jout) {
|
||||||
jstring jresult = 0 ;
|
jint jresult = 0 ;
|
||||||
|
jstring jinfo = 0;
|
||||||
void *handle = (void *) 0 ;
|
void *handle = (void *) 0 ;
|
||||||
bst_ulong len[1];
|
bst_ulong len[1];
|
||||||
*len = 0;
|
*len = 0;
|
||||||
char *result = 0 ;
|
char *result[1];
|
||||||
handle = *(void **)&jhandle;
|
handle = *(void **)&jhandle;
|
||||||
|
|
||||||
result = (char *)XGBoosterGetModelRaw(handle, len);
|
jresult = (jint)XGBoosterGetModelRaw(handle, len, (const char **) result);
|
||||||
if (result) jresult = jenv->NewStringUTF((const char *)result);
|
if (*result){
|
||||||
|
jinfo = jenv->NewStringUTF((const char *) *result);
|
||||||
|
jenv->SetObjectArrayElement(jout, 0, jinfo);
|
||||||
|
}
|
||||||
return jresult;
|
return jresult;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -604,15 +647,16 @@ JNIEXPORT jstring JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterGe
|
|||||||
* Method: XGBoosterDumpModel
|
* Method: XGBoosterDumpModel
|
||||||
* Signature: (JLjava/lang/String;I)[Ljava/lang/String;
|
* Signature: (JLjava/lang/String;I)[Ljava/lang/String;
|
||||||
*/
|
*/
|
||||||
JNIEXPORT jobjectArray JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterDumpModel
|
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterDumpModel
|
||||||
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfmap, jint jwith_stats) {
|
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfmap, jint jwith_stats, jobjectArray jout) {
|
||||||
|
jint jresult = 0;
|
||||||
void *handle = (void *) 0 ;
|
void *handle = (void *) 0 ;
|
||||||
char *fmap = (char *) 0 ;
|
char *fmap = (char *) 0 ;
|
||||||
int with_stats ;
|
int with_stats ;
|
||||||
bst_ulong len[1];
|
bst_ulong len[1];
|
||||||
*len = 0;
|
*len = 0;
|
||||||
|
|
||||||
char **result = 0 ;
|
char **result[1];
|
||||||
handle = *(void **)&jhandle;
|
handle = *(void **)&jhandle;
|
||||||
fmap = 0;
|
fmap = 0;
|
||||||
if (jfmap) {
|
if (jfmap) {
|
||||||
@ -621,14 +665,16 @@ JNIEXPORT jobjectArray JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoos
|
|||||||
}
|
}
|
||||||
with_stats = (int)jwith_stats;
|
with_stats = (int)jwith_stats;
|
||||||
|
|
||||||
result = (char **)XGBoosterDumpModel(handle, (char const *)fmap, with_stats, len);
|
jresult = (jint) XGBoosterDumpModel(handle, (const char *)fmap, with_stats, len, (const char ***) result);
|
||||||
|
|
||||||
jsize jlen = (jsize)*len;
|
jsize jlen = (jsize)*len;
|
||||||
jobjectArray jresult = jenv->NewObjectArray(jlen, jenv->FindClass("java/lang/String"), jenv->NewStringUTF(""));
|
jobjectArray jinfos = jenv->NewObjectArray(jlen, jenv->FindClass("java/lang/String"), jenv->NewStringUTF(""));
|
||||||
for(int i=0 ; i<jlen; i++) {
|
for(int i=0 ; i<jlen; i++) {
|
||||||
jenv->SetObjectArrayElement(jresult, i, jenv->NewStringUTF((const char*)result[i]));
|
jenv->SetObjectArrayElement(jinfos, i, jenv->NewStringUTF((const char*) result[0][i]));
|
||||||
}
|
}
|
||||||
|
jenv->SetObjectArrayElement(jout, 0, jinfos);
|
||||||
|
|
||||||
if (fmap) jenv->ReleaseStringUTFChars(jfmap, (const char *)fmap);
|
if (fmap) jenv->ReleaseStringUTFChars(jfmap, (const char *)fmap);
|
||||||
|
|
||||||
return jresult;
|
return jresult;
|
||||||
}
|
}
|
||||||
@ -9,203 +9,211 @@ extern "C" {
|
|||||||
#endif
|
#endif
|
||||||
/*
|
/*
|
||||||
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
|
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
|
||||||
* Method: XGDMatrixCreateFromFile
|
* Method: XGBGetLastError
|
||||||
* Signature: (Ljava/lang/String;I)J
|
* Signature: ()Ljava/lang/String;
|
||||||
*/
|
*/
|
||||||
JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromFile
|
JNIEXPORT jstring JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBGetLastError
|
||||||
(JNIEnv *, jclass, jstring, jint);
|
(JNIEnv *, jclass);
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
|
||||||
|
* Method: XGDMatrixCreateFromFile
|
||||||
|
* Signature: (Ljava/lang/String;I[J)I
|
||||||
|
*/
|
||||||
|
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromFile
|
||||||
|
(JNIEnv *, jclass, jstring, jint, jlongArray);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
|
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
|
||||||
* Method: XGDMatrixCreateFromCSR
|
* Method: XGDMatrixCreateFromCSR
|
||||||
* Signature: ([J[J[F)J
|
* Signature: ([J[I[F[J)I
|
||||||
*/
|
*/
|
||||||
JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromCSR
|
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromCSR
|
||||||
(JNIEnv *, jclass, jlongArray, jintArray, jfloatArray);
|
(JNIEnv *, jclass, jlongArray, jintArray, jfloatArray, jlongArray);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
|
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
|
||||||
* Method: XGDMatrixCreateFromCSC
|
* Method: XGDMatrixCreateFromCSC
|
||||||
* Signature: ([J[J[F)J
|
* Signature: ([J[I[F[J)I
|
||||||
*/
|
*/
|
||||||
JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromCSC
|
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromCSC
|
||||||
(JNIEnv *, jclass, jlongArray, jintArray, jfloatArray);
|
(JNIEnv *, jclass, jlongArray, jintArray, jfloatArray, jlongArray);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
|
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
|
||||||
* Method: XGDMatrixCreateFromMat
|
* Method: XGDMatrixCreateFromMat
|
||||||
* Signature: ([FIIF)J
|
* Signature: ([FIIF[J)I
|
||||||
*/
|
*/
|
||||||
JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromMat
|
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromMat
|
||||||
(JNIEnv *, jclass, jfloatArray, jint, jint, jfloat);
|
(JNIEnv *, jclass, jfloatArray, jint, jint, jfloat, jlongArray);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
|
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
|
||||||
* Method: XGDMatrixSliceDMatrix
|
* Method: XGDMatrixSliceDMatrix
|
||||||
* Signature: (J[I)J
|
* Signature: (J[I[J)I
|
||||||
*/
|
*/
|
||||||
JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSliceDMatrix
|
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSliceDMatrix
|
||||||
(JNIEnv *, jclass, jlong, jintArray);
|
(JNIEnv *, jclass, jlong, jintArray, jlongArray);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
|
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
|
||||||
* Method: XGDMatrixFree
|
* Method: XGDMatrixFree
|
||||||
* Signature: (J)V
|
* Signature: (J)I
|
||||||
*/
|
*/
|
||||||
JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixFree
|
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixFree
|
||||||
(JNIEnv *, jclass, jlong);
|
(JNIEnv *, jclass, jlong);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
|
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
|
||||||
* Method: XGDMatrixSaveBinary
|
* Method: XGDMatrixSaveBinary
|
||||||
* Signature: (JLjava/lang/String;I)V
|
* Signature: (JLjava/lang/String;I)I
|
||||||
*/
|
*/
|
||||||
JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSaveBinary
|
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSaveBinary
|
||||||
(JNIEnv *, jclass, jlong, jstring, jint);
|
(JNIEnv *, jclass, jlong, jstring, jint);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
|
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
|
||||||
* Method: XGDMatrixSetFloatInfo
|
* Method: XGDMatrixSetFloatInfo
|
||||||
* Signature: (JLjava/lang/String;[F)V
|
* Signature: (JLjava/lang/String;[F)I
|
||||||
*/
|
*/
|
||||||
JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetFloatInfo
|
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetFloatInfo
|
||||||
(JNIEnv *, jclass, jlong, jstring, jfloatArray);
|
(JNIEnv *, jclass, jlong, jstring, jfloatArray);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
|
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
|
||||||
* Method: XGDMatrixSetUIntInfo
|
* Method: XGDMatrixSetUIntInfo
|
||||||
* Signature: (JLjava/lang/String;[I)V
|
* Signature: (JLjava/lang/String;[I)I
|
||||||
*/
|
*/
|
||||||
JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetUIntInfo
|
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetUIntInfo
|
||||||
(JNIEnv *, jclass, jlong, jstring, jintArray);
|
(JNIEnv *, jclass, jlong, jstring, jintArray);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
|
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
|
||||||
* Method: XGDMatrixSetGroup
|
* Method: XGDMatrixSetGroup
|
||||||
* Signature: (J[I)V
|
* Signature: (J[I)I
|
||||||
*/
|
*/
|
||||||
JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetGroup
|
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetGroup
|
||||||
(JNIEnv *, jclass, jlong, jintArray);
|
(JNIEnv *, jclass, jlong, jintArray);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
|
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
|
||||||
* Method: XGDMatrixGetFloatInfo
|
* Method: XGDMatrixGetFloatInfo
|
||||||
* Signature: (JLjava/lang/String;)[F
|
* Signature: (JLjava/lang/String;[[F)I
|
||||||
*/
|
*/
|
||||||
JNIEXPORT jfloatArray JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixGetFloatInfo
|
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixGetFloatInfo
|
||||||
(JNIEnv *, jclass, jlong, jstring);
|
(JNIEnv *, jclass, jlong, jstring, jobjectArray);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
|
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
|
||||||
* Method: XGDMatrixGetUIntInfo
|
* Method: XGDMatrixGetUIntInfo
|
||||||
* Signature: (JLjava/lang/String;)[I
|
* Signature: (JLjava/lang/String;[[I)I
|
||||||
*/
|
*/
|
||||||
JNIEXPORT jintArray JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixGetUIntInfo
|
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixGetUIntInfo
|
||||||
(JNIEnv *, jclass, jlong, jstring);
|
(JNIEnv *, jclass, jlong, jstring, jobjectArray);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
|
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
|
||||||
* Method: XGDMatrixNumRow
|
* Method: XGDMatrixNumRow
|
||||||
* Signature: (J)J
|
* Signature: (J[J)I
|
||||||
*/
|
*/
|
||||||
JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixNumRow
|
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixNumRow
|
||||||
(JNIEnv *, jclass, jlong);
|
(JNIEnv *, jclass, jlong, jlongArray);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
|
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
|
||||||
* Method: XGBoosterCreate
|
* Method: XGBoosterCreate
|
||||||
* Signature: ([J)J
|
* Signature: ([J[J)I
|
||||||
*/
|
*/
|
||||||
JNIEXPORT jlong JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterCreate
|
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterCreate
|
||||||
(JNIEnv *, jclass, jlongArray);
|
(JNIEnv *, jclass, jlongArray, jlongArray);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
|
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
|
||||||
* Method: XGBoosterFree
|
* Method: XGBoosterFree
|
||||||
* Signature: (J)V
|
* Signature: (J)I
|
||||||
*/
|
*/
|
||||||
JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterFree
|
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterFree
|
||||||
(JNIEnv *, jclass, jlong);
|
(JNIEnv *, jclass, jlong);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
|
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
|
||||||
* Method: XGBoosterSetParam
|
* Method: XGBoosterSetParam
|
||||||
* Signature: (JLjava/lang/String;Ljava/lang/String;)V
|
* Signature: (JLjava/lang/String;Ljava/lang/String;)I
|
||||||
*/
|
*/
|
||||||
JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterSetParam
|
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterSetParam
|
||||||
(JNIEnv *, jclass, jlong, jstring, jstring);
|
(JNIEnv *, jclass, jlong, jstring, jstring);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
|
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
|
||||||
* Method: XGBoosterUpdateOneIter
|
* Method: XGBoosterUpdateOneIter
|
||||||
* Signature: (JIJ)V
|
* Signature: (JIJ)I
|
||||||
*/
|
*/
|
||||||
JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterUpdateOneIter
|
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterUpdateOneIter
|
||||||
(JNIEnv *, jclass, jlong, jint, jlong);
|
(JNIEnv *, jclass, jlong, jint, jlong);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
|
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
|
||||||
* Method: XGBoosterBoostOneIter
|
* Method: XGBoosterBoostOneIter
|
||||||
* Signature: (JJ[F[F)V
|
* Signature: (JJ[F[F)I
|
||||||
*/
|
*/
|
||||||
JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterBoostOneIter
|
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterBoostOneIter
|
||||||
(JNIEnv *, jclass, jlong, jlong, jfloatArray, jfloatArray);
|
(JNIEnv *, jclass, jlong, jlong, jfloatArray, jfloatArray);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
|
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
|
||||||
* Method: XGBoosterEvalOneIter
|
* Method: XGBoosterEvalOneIter
|
||||||
* Signature: (JI[J[Ljava/lang/String;)Ljava/lang/String;
|
* Signature: (JI[J[Ljava/lang/String;[Ljava/lang/String;)I
|
||||||
*/
|
*/
|
||||||
JNIEXPORT jstring JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterEvalOneIter
|
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterEvalOneIter
|
||||||
(JNIEnv *, jclass, jlong, jint, jlongArray, jobjectArray);
|
(JNIEnv *, jclass, jlong, jint, jlongArray, jobjectArray, jobjectArray);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
|
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
|
||||||
* Method: XGBoosterPredict
|
* Method: XGBoosterPredict
|
||||||
* Signature: (JJIJ)[F
|
* Signature: (JJIJ[[F)I
|
||||||
*/
|
*/
|
||||||
JNIEXPORT jfloatArray JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterPredict
|
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterPredict
|
||||||
(JNIEnv *, jclass, jlong, jlong, jint, jlong);
|
(JNIEnv *, jclass, jlong, jlong, jint, jlong, jobjectArray);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
|
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
|
||||||
* Method: XGBoosterLoadModel
|
* Method: XGBoosterLoadModel
|
||||||
* Signature: (JLjava/lang/String;)V
|
* Signature: (JLjava/lang/String;)I
|
||||||
*/
|
*/
|
||||||
JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterLoadModel
|
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterLoadModel
|
||||||
(JNIEnv *, jclass, jlong, jstring);
|
(JNIEnv *, jclass, jlong, jstring);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
|
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
|
||||||
* Method: XGBoosterSaveModel
|
* Method: XGBoosterSaveModel
|
||||||
* Signature: (JLjava/lang/String;)V
|
* Signature: (JLjava/lang/String;)I
|
||||||
*/
|
*/
|
||||||
JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterSaveModel
|
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterSaveModel
|
||||||
(JNIEnv *, jclass, jlong, jstring);
|
(JNIEnv *, jclass, jlong, jstring);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
|
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
|
||||||
* Method: XGBoosterLoadModelFromBuffer
|
* Method: XGBoosterLoadModelFromBuffer
|
||||||
* Signature: (JJJ)V
|
* Signature: (JJJ)I
|
||||||
*/
|
*/
|
||||||
JNIEXPORT void JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterLoadModelFromBuffer
|
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterLoadModelFromBuffer
|
||||||
(JNIEnv *, jclass, jlong, jlong, jlong);
|
(JNIEnv *, jclass, jlong, jlong, jlong);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
|
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
|
||||||
* Method: XGBoosterGetModelRaw
|
* Method: XGBoosterGetModelRaw
|
||||||
* Signature: (J)Ljava/lang/String;
|
* Signature: (J[Ljava/lang/String;)I
|
||||||
*/
|
*/
|
||||||
JNIEXPORT jstring JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterGetModelRaw
|
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterGetModelRaw
|
||||||
(JNIEnv *, jclass, jlong);
|
(JNIEnv *, jclass, jlong, jobjectArray);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
|
* Class: org_dmlc_xgboost4j_wrapper_XgboostJNI
|
||||||
* Method: XGBoosterDumpModel
|
* Method: XGBoosterDumpModel
|
||||||
* Signature: (JLjava/lang/String;I)[Ljava/lang/String;
|
* Signature: (JLjava/lang/String;I[[Ljava/lang/String;)I
|
||||||
*/
|
*/
|
||||||
JNIEXPORT jobjectArray JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterDumpModel
|
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterDumpModel
|
||||||
(JNIEnv *, jclass, jlong, jstring, jint);
|
(JNIEnv *, jclass, jlong, jstring, jint, jobjectArray);
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user