update java wrapper for new fault handle API
This commit is contained in:
@@ -30,6 +30,8 @@ import org.apache.commons.logging.Log;
|
||||
import org.apache.commons.logging.LogFactory;
|
||||
|
||||
import org.dmlc.xgboost4j.util.Initializer;
|
||||
import org.dmlc.xgboost4j.util.ErrorHandle;
|
||||
import org.dmlc.xgboost4j.util.XgboostError;
|
||||
import org.dmlc.xgboost4j.wrapper.XgboostJNI;
|
||||
|
||||
|
||||
@@ -57,8 +59,9 @@ public final class Booster {
|
||||
* init Booster from dMatrixs
|
||||
* @param params parameters
|
||||
* @param dMatrixs DMatrix array
|
||||
* @throws org.dmlc.xgboost4j.util.XgboostError
|
||||
*/
|
||||
public Booster(Iterable<Entry<String, Object>> params, DMatrix[] dMatrixs) {
|
||||
public Booster(Iterable<Entry<String, Object>> params, DMatrix[] dMatrixs) throws XgboostError {
|
||||
init(dMatrixs);
|
||||
setParam("seed","0");
|
||||
setParams(params);
|
||||
@@ -70,9 +73,11 @@ public final class Booster {
|
||||
* load model from modelPath
|
||||
* @param params parameters
|
||||
* @param modelPath booster modelPath (model generated by booster.saveModel)
|
||||
* @throws org.dmlc.xgboost4j.util.XgboostError
|
||||
*/
|
||||
public Booster(Iterable<Entry<String, Object>> params, String modelPath) {
|
||||
handle = XgboostJNI.XGBoosterCreate(new long[] {});
|
||||
public Booster(Iterable<Entry<String, Object>> params, String modelPath) throws XgboostError {
|
||||
long[] out = new long[1];
|
||||
init(null);
|
||||
loadModel(modelPath);
|
||||
setParam("seed","0");
|
||||
setParams(params);
|
||||
@@ -81,28 +86,33 @@ public final class Booster {
|
||||
|
||||
|
||||
|
||||
private void init(DMatrix[] dMatrixs) {
|
||||
private void init(DMatrix[] dMatrixs) throws XgboostError {
|
||||
long[] handles = null;
|
||||
if(dMatrixs != null) {
|
||||
handles = dMatrixs2handles(dMatrixs);
|
||||
}
|
||||
handle = XgboostJNI.XGBoosterCreate(handles);
|
||||
long[] out = new long[1];
|
||||
ErrorHandle.checkCall(XgboostJNI.XGBoosterCreate(handles, out));
|
||||
|
||||
handle = out[0];
|
||||
}
|
||||
|
||||
/**
|
||||
* set parameter
|
||||
* @param key param name
|
||||
* @param value param value
|
||||
* @throws org.dmlc.xgboost4j.util.XgboostError
|
||||
*/
|
||||
public final void setParam(String key, String value) {
|
||||
XgboostJNI.XGBoosterSetParam(handle, key, value);
|
||||
public final void setParam(String key, String value) throws XgboostError {
|
||||
ErrorHandle.checkCall(XgboostJNI.XGBoosterSetParam(handle, key, value));
|
||||
}
|
||||
|
||||
/**
|
||||
* set parameters
|
||||
* @param params parameters key-value map
|
||||
* @throws org.dmlc.xgboost4j.util.XgboostError
|
||||
*/
|
||||
public void setParams(Iterable<Entry<String, Object>> params) {
|
||||
public void setParams(Iterable<Entry<String, Object>> params) throws XgboostError {
|
||||
if(params!=null) {
|
||||
for(Map.Entry<String, Object> entry : params) {
|
||||
setParam(entry.getKey(), entry.getValue().toString());
|
||||
@@ -115,9 +125,10 @@ public final class Booster {
|
||||
* Update (one iteration)
|
||||
* @param dtrain training data
|
||||
* @param iter current iteration number
|
||||
* @throws org.dmlc.xgboost4j.util.XgboostError
|
||||
*/
|
||||
public void update(DMatrix dtrain, int iter) {
|
||||
XgboostJNI.XGBoosterUpdateOneIter(handle, iter, dtrain.getHandle());
|
||||
public void update(DMatrix dtrain, int iter) throws XgboostError {
|
||||
ErrorHandle.checkCall(XgboostJNI.XGBoosterUpdateOneIter(handle, iter, dtrain.getHandle()));
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -125,8 +136,9 @@ public final class Booster {
|
||||
* @param dtrain training data
|
||||
* @param iter current iteration number
|
||||
* @param obj customized objective class
|
||||
* @throws org.dmlc.xgboost4j.util.XgboostError
|
||||
*/
|
||||
public void update(DMatrix dtrain, int iter, IObjective obj) {
|
||||
public void update(DMatrix dtrain, int iter, IObjective obj) throws XgboostError {
|
||||
float[][] predicts = predict(dtrain, true);
|
||||
List<float[]> gradients = obj.getGradient(predicts, dtrain);
|
||||
boost(dtrain, gradients.get(0), gradients.get(1));
|
||||
@@ -137,12 +149,13 @@ public final class Booster {
|
||||
* @param dtrain training data
|
||||
* @param grad first order of gradient
|
||||
* @param hess second order of gradient
|
||||
* @throws org.dmlc.xgboost4j.util.XgboostError
|
||||
*/
|
||||
public void boost(DMatrix dtrain, float[] grad, float[] hess) {
|
||||
public void boost(DMatrix dtrain, float[] grad, float[] hess) throws XgboostError {
|
||||
if(grad.length != hess.length) {
|
||||
throw new AssertionError(String.format("grad/hess length mismatch %s / %s", grad.length, hess.length));
|
||||
}
|
||||
XgboostJNI.XGBoosterBoostOneIter(handle, dtrain.getHandle(), grad, hess);
|
||||
ErrorHandle.checkCall(XgboostJNI.XGBoosterBoostOneIter(handle, dtrain.getHandle(), grad, hess));
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -151,11 +164,13 @@ public final class Booster {
|
||||
* @param evalNames name for eval dmatrixs, used for check results
|
||||
* @param iter current eval iteration
|
||||
* @return eval information
|
||||
* @throws org.dmlc.xgboost4j.util.XgboostError
|
||||
*/
|
||||
public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, int iter) {
|
||||
public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, int iter) throws XgboostError {
|
||||
long[] handles = dMatrixs2handles(evalMatrixs);
|
||||
String evalInfo = XgboostJNI.XGBoosterEvalOneIter(handle, iter, handles, evalNames);
|
||||
return evalInfo;
|
||||
String[] evalInfo = new String[1];
|
||||
ErrorHandle.checkCall(XgboostJNI.XGBoosterEvalOneIter(handle, iter, handles, evalNames, evalInfo));
|
||||
return evalInfo[0];
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -165,8 +180,9 @@ public final class Booster {
|
||||
* @param iter current eval iteration
|
||||
* @param eval customized evaluation
|
||||
* @return eval information
|
||||
* @throws org.dmlc.xgboost4j.util.XgboostError
|
||||
*/
|
||||
public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, int iter, IEvaluation eval) {
|
||||
public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, int iter, IEvaluation eval) throws XgboostError {
|
||||
String evalInfo = "";
|
||||
for(int i=0; i<evalNames.length; i++) {
|
||||
String evalName = evalNames[i];
|
||||
@@ -184,10 +200,12 @@ public final class Booster {
|
||||
* @param evalNames name for eval dmatrixs, used for check results
|
||||
* @param iter current eval iteration
|
||||
* @return eval information
|
||||
* @throws org.dmlc.xgboost4j.util.XgboostError
|
||||
*/
|
||||
public String evalSet(long[] dHandles, String[] evalNames, int iter) {
|
||||
String evalInfo = XgboostJNI.XGBoosterEvalOneIter(handle, iter, dHandles, evalNames);
|
||||
return evalInfo;
|
||||
public String evalSet(long[] dHandles, String[] evalNames, int iter) throws XgboostError {
|
||||
String[] evalInfo = new String[1];
|
||||
ErrorHandle.checkCall(XgboostJNI.XGBoosterEvalOneIter(handle, iter, dHandles, evalNames, evalInfo));
|
||||
return evalInfo[0];
|
||||
}
|
||||
|
||||
|
||||
@@ -197,8 +215,9 @@ public final class Booster {
|
||||
* @param evalName
|
||||
* @param iter
|
||||
* @return eval information
|
||||
* @throws org.dmlc.xgboost4j.util.XgboostError
|
||||
*/
|
||||
public String eval(DMatrix evalMat, String evalName, int iter) {
|
||||
public String eval(DMatrix evalMat, String evalName, int iter) throws XgboostError {
|
||||
DMatrix[] evalMats = new DMatrix[] {evalMat};
|
||||
String[] evalNames = new String[] {evalName};
|
||||
return evalSet(evalMats, evalNames, iter);
|
||||
@@ -212,7 +231,7 @@ public final class Booster {
|
||||
* @param predLeaf
|
||||
* @return predict results
|
||||
*/
|
||||
private synchronized float[][] pred(DMatrix data, boolean outPutMargin, long treeLimit, boolean predLeaf) {
|
||||
private synchronized float[][] pred(DMatrix data, boolean outPutMargin, long treeLimit, boolean predLeaf) throws XgboostError {
|
||||
int optionMask = 0;
|
||||
if(outPutMargin) {
|
||||
optionMask = 1;
|
||||
@@ -220,15 +239,16 @@ public final class Booster {
|
||||
if(predLeaf) {
|
||||
optionMask = 2;
|
||||
}
|
||||
float[] rawPredicts = XgboostJNI.XGBoosterPredict(handle, data.getHandle(), optionMask, treeLimit);
|
||||
float[][] rawPredicts = new float[1][];
|
||||
ErrorHandle.checkCall(XgboostJNI.XGBoosterPredict(handle, data.getHandle(), optionMask, treeLimit, rawPredicts));
|
||||
int row = (int) data.rowNum();
|
||||
int col = (int) rawPredicts.length/row;
|
||||
int col = (int) rawPredicts[0].length/row;
|
||||
float[][] predicts = new float[row][col];
|
||||
int r,c;
|
||||
for(int i=0; i< rawPredicts.length; i++) {
|
||||
for(int i=0; i< rawPredicts[0].length; i++) {
|
||||
r = i/col;
|
||||
c = i%col;
|
||||
predicts[r][c] = rawPredicts[i];
|
||||
predicts[r][c] = rawPredicts[0][i];
|
||||
}
|
||||
return predicts;
|
||||
}
|
||||
@@ -237,8 +257,9 @@ public final class Booster {
|
||||
* Predict with data
|
||||
* @param data dmatrix storing the input
|
||||
* @return predict result
|
||||
* @throws org.dmlc.xgboost4j.util.XgboostError
|
||||
*/
|
||||
public float[][] predict(DMatrix data) {
|
||||
public float[][] predict(DMatrix data) throws XgboostError {
|
||||
return pred(data, false, 0, false);
|
||||
}
|
||||
|
||||
@@ -247,8 +268,9 @@ public final class Booster {
|
||||
* @param data dmatrix storing the input
|
||||
* @param outPutMargin Whether to output the raw untransformed margin value.
|
||||
* @return predict result
|
||||
* @throws org.dmlc.xgboost4j.util.XgboostError
|
||||
*/
|
||||
public float[][] predict(DMatrix data, boolean outPutMargin) {
|
||||
public float[][] predict(DMatrix data, boolean outPutMargin) throws XgboostError {
|
||||
return pred(data, outPutMargin, 0, false);
|
||||
}
|
||||
|
||||
@@ -258,8 +280,9 @@ public final class Booster {
|
||||
* @param outPutMargin Whether to output the raw untransformed margin value.
|
||||
* @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees).
|
||||
* @return predict result
|
||||
* @throws org.dmlc.xgboost4j.util.XgboostError
|
||||
*/
|
||||
public float[][] predict(DMatrix data, boolean outPutMargin, long treeLimit) {
|
||||
public float[][] predict(DMatrix data, boolean outPutMargin, long treeLimit) throws XgboostError {
|
||||
return pred(data, outPutMargin, treeLimit, false);
|
||||
}
|
||||
|
||||
@@ -272,8 +295,9 @@ public final class Booster {
|
||||
Note that the leaf index of a tree is unique per tree, so you may find leaf 1
|
||||
in both tree 1 and tree 0.
|
||||
* @return predict result
|
||||
* @throws org.dmlc.xgboost4j.util.XgboostError
|
||||
*/
|
||||
public float[][] predict(DMatrix data , long treeLimit, boolean predLeaf) {
|
||||
public float[][] predict(DMatrix data , long treeLimit, boolean predLeaf) throws XgboostError {
|
||||
return pred(data, false, treeLimit, predLeaf);
|
||||
}
|
||||
|
||||
@@ -293,14 +317,16 @@ public final class Booster {
|
||||
* get the dump of the model as a string array
|
||||
* @param withStats Controls whether the split statistics are output.
|
||||
* @return dumped model information
|
||||
* @throws org.dmlc.xgboost4j.util.XgboostError
|
||||
*/
|
||||
public String[] getDumpInfo(boolean withStats) {
|
||||
public String[] getDumpInfo(boolean withStats) throws XgboostError {
|
||||
int statsFlag = 0;
|
||||
if(withStats) {
|
||||
statsFlag = 1;
|
||||
}
|
||||
String[] modelInfos = XgboostJNI.XGBoosterDumpModel(handle, "", statsFlag);
|
||||
return modelInfos;
|
||||
String[][] modelInfos = new String[1][];
|
||||
ErrorHandle.checkCall(XgboostJNI.XGBoosterDumpModel(handle, "", statsFlag, modelInfos));
|
||||
return modelInfos[0];
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -308,14 +334,16 @@ public final class Booster {
|
||||
* @param featureMap featureMap file
|
||||
* @param withStats Controls whether the split statistics are output.
|
||||
* @return dumped model information
|
||||
* @throws org.dmlc.xgboost4j.util.XgboostError
|
||||
*/
|
||||
public String[] getDumpInfo(String featureMap, boolean withStats) {
|
||||
public String[] getDumpInfo(String featureMap, boolean withStats) throws XgboostError {
|
||||
int statsFlag = 0;
|
||||
if(withStats) {
|
||||
statsFlag = 1;
|
||||
}
|
||||
String[] modelInfos = XgboostJNI.XGBoosterDumpModel(handle, featureMap, statsFlag);
|
||||
return modelInfos;
|
||||
String[][] modelInfos = new String[1][];
|
||||
ErrorHandle.checkCall(XgboostJNI.XGBoosterDumpModel(handle, featureMap, statsFlag, modelInfos));
|
||||
return modelInfos[0];
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -326,8 +354,9 @@ public final class Booster {
|
||||
* @throws FileNotFoundException
|
||||
* @throws UnsupportedEncodingException
|
||||
* @throws IOException
|
||||
* @throws org.dmlc.xgboost4j.util.XgboostError
|
||||
*/
|
||||
public void dumpModel(String modelPath, boolean withStats) throws FileNotFoundException, UnsupportedEncodingException, IOException {
|
||||
public void dumpModel(String modelPath, boolean withStats) throws FileNotFoundException, UnsupportedEncodingException, IOException, XgboostError {
|
||||
File tf = new File(modelPath);
|
||||
FileOutputStream out = new FileOutputStream(tf);
|
||||
BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(out, "UTF-8"));
|
||||
@@ -352,8 +381,9 @@ public final class Booster {
|
||||
* @throws FileNotFoundException
|
||||
* @throws UnsupportedEncodingException
|
||||
* @throws IOException
|
||||
* @throws org.dmlc.xgboost4j.util.XgboostError
|
||||
*/
|
||||
public void dumpModel(String modelPath, String featureMap, boolean withStats) throws FileNotFoundException, UnsupportedEncodingException, IOException {
|
||||
public void dumpModel(String modelPath, String featureMap, boolean withStats) throws FileNotFoundException, UnsupportedEncodingException, IOException, XgboostError {
|
||||
File tf = new File(modelPath);
|
||||
FileOutputStream out = new FileOutputStream(tf);
|
||||
BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(out, "UTF-8"));
|
||||
@@ -372,8 +402,9 @@ public final class Booster {
|
||||
/**
|
||||
* get importance of each feature
|
||||
* @return featureMap key: feature index, value: feature importance score
|
||||
* @throws org.dmlc.xgboost4j.util.XgboostError
|
||||
*/
|
||||
public Map<String, Integer> getFeatureScore() {
|
||||
public Map<String, Integer> getFeatureScore() throws XgboostError {
|
||||
String[] modelInfos = getDumpInfo(false);
|
||||
Map<String, Integer> featureScore = new HashMap<>();
|
||||
for(String tree : modelInfos) {
|
||||
@@ -400,8 +431,9 @@ public final class Booster {
|
||||
* get importance of each feature
|
||||
* @param featureMap file to save dumped model info
|
||||
* @return featureMap key: feature index, value: feature importance score
|
||||
* @throws org.dmlc.xgboost4j.util.XgboostError
|
||||
*/
|
||||
public Map<String, Integer> getFeatureScore(String featureMap) {
|
||||
public Map<String, Integer> getFeatureScore(String featureMap) throws XgboostError {
|
||||
String[] modelInfos = getDumpInfo(featureMap, false);
|
||||
Map<String, Integer> featureScore = new HashMap<>();
|
||||
for(String tree : modelInfos) {
|
||||
|
||||
@@ -18,6 +18,8 @@ package org.dmlc.xgboost4j;
|
||||
import java.io.IOException;
|
||||
import org.apache.commons.logging.Log;
|
||||
import org.apache.commons.logging.LogFactory;
|
||||
import org.dmlc.xgboost4j.util.ErrorHandle;
|
||||
import org.dmlc.xgboost4j.util.XgboostError;
|
||||
import org.dmlc.xgboost4j.util.Initializer;
|
||||
import org.dmlc.xgboost4j.wrapper.XgboostJNI;
|
||||
|
||||
@@ -50,9 +52,12 @@ public class DMatrix {
|
||||
/**
|
||||
* init DMatrix from file (svmlight format)
|
||||
* @param dataPath
|
||||
* @throws org.dmlc.xgboost4j.util.XgboostError
|
||||
*/
|
||||
public DMatrix(String dataPath) {
|
||||
handle = XgboostJNI.XGDMatrixCreateFromFile(dataPath, 1);
|
||||
public DMatrix(String dataPath) throws XgboostError {
|
||||
long[] out = new long[1];
|
||||
ErrorHandle.checkCall(XgboostJNI.XGDMatrixCreateFromFile(dataPath, 1, out));
|
||||
handle = out[0];
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -61,17 +66,20 @@ public class DMatrix {
|
||||
* @param indices Indices (colIndexs for CSR or rowIndexs for CSC)
|
||||
* @param data non zero values (sequence by row for CSR or by col for CSC)
|
||||
* @param st sparse matrix type (CSR or CSC)
|
||||
* @throws org.dmlc.xgboost4j.util.XgboostError
|
||||
*/
|
||||
public DMatrix(long[] headers, int[] indices, float[] data, SparseType st) {
|
||||
public DMatrix(long[] headers, int[] indices, float[] data, SparseType st) throws XgboostError {
|
||||
long[] out = new long[1];
|
||||
if(st == SparseType.CSR) {
|
||||
handle = XgboostJNI.XGDMatrixCreateFromCSR(headers, indices, data);
|
||||
ErrorHandle.checkCall(XgboostJNI.XGDMatrixCreateFromCSR(headers, indices, data, out));
|
||||
}
|
||||
else if(st == SparseType.CSC) {
|
||||
handle = XgboostJNI.XGDMatrixCreateFromCSC(headers, indices, data);
|
||||
ErrorHandle.checkCall(XgboostJNI.XGDMatrixCreateFromCSC(headers, indices, data, out));
|
||||
}
|
||||
else {
|
||||
throw new UnknownError("unknow sparsetype");
|
||||
}
|
||||
handle = out[0];
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -79,10 +87,13 @@ public class DMatrix {
|
||||
* @param data data values
|
||||
* @param nrow number of rows
|
||||
* @param ncol number of columns
|
||||
* @throws org.dmlc.xgboost4j.util.XgboostError
|
||||
*/
|
||||
public DMatrix(float[] data, int nrow, int ncol) {
|
||||
handle = XgboostJNI.XGDMatrixCreateFromMat(data, nrow, ncol, 0.0f);
|
||||
}
|
||||
public DMatrix(float[] data, int nrow, int ncol) throws XgboostError {
|
||||
long[] out = new long[1];
|
||||
ErrorHandle.checkCall(XgboostJNI.XGDMatrixCreateFromMat(data, nrow, ncol, 0.0f, out));
|
||||
handle = out[0];
|
||||
}
|
||||
|
||||
/**
|
||||
* used for DMatrix slice
|
||||
@@ -98,33 +109,36 @@ public class DMatrix {
|
||||
* set label of dmatrix
|
||||
* @param labels
|
||||
*/
|
||||
public void setLabel(float[] labels) {
|
||||
XgboostJNI.XGDMatrixSetFloatInfo(handle, "label", labels);
|
||||
public void setLabel(float[] labels) throws XgboostError {
|
||||
ErrorHandle.checkCall(XgboostJNI.XGDMatrixSetFloatInfo(handle, "label", labels));
|
||||
}
|
||||
|
||||
/**
|
||||
* set weight of each instance
|
||||
* @param weights
|
||||
* @throws org.dmlc.xgboost4j.util.XgboostError
|
||||
*/
|
||||
public void setWeight(float[] weights) {
|
||||
XgboostJNI.XGDMatrixSetFloatInfo(handle, "weight", weights);
|
||||
public void setWeight(float[] weights) throws XgboostError {
|
||||
ErrorHandle.checkCall(XgboostJNI.XGDMatrixSetFloatInfo(handle, "weight", weights));
|
||||
}
|
||||
|
||||
/**
|
||||
* if specified, xgboost will start from this init margin
|
||||
* can be used to specify initial prediction to boost from
|
||||
* @param baseMargin
|
||||
* @throws org.dmlc.xgboost4j.util.XgboostError
|
||||
*/
|
||||
public void setBaseMargin(float[] baseMargin) {
|
||||
XgboostJNI.XGDMatrixSetFloatInfo(handle, "base_margin", baseMargin);
|
||||
public void setBaseMargin(float[] baseMargin) throws XgboostError {
|
||||
ErrorHandle.checkCall(XgboostJNI.XGDMatrixSetFloatInfo(handle, "base_margin", baseMargin));
|
||||
}
|
||||
|
||||
/**
|
||||
* if specified, xgboost will start from this init margin
|
||||
* can be used to specify initial prediction to boost from
|
||||
* @param baseMargin
|
||||
* @throws org.dmlc.xgboost4j.util.XgboostError
|
||||
*/
|
||||
public void setBaseMargin(float[][] baseMargin) {
|
||||
public void setBaseMargin(float[][] baseMargin) throws XgboostError {
|
||||
float[] flattenMargin = flatten(baseMargin);
|
||||
setBaseMargin(flattenMargin);
|
||||
}
|
||||
@@ -132,42 +146,48 @@ public class DMatrix {
|
||||
/**
|
||||
* Set group sizes of DMatrix (used for ranking)
|
||||
* @param group
|
||||
* @throws org.dmlc.xgboost4j.util.XgboostError
|
||||
*/
|
||||
public void setGroup(int[] group) {
|
||||
XgboostJNI.XGDMatrixSetGroup(handle, group);
|
||||
public void setGroup(int[] group) throws XgboostError {
|
||||
ErrorHandle.checkCall(XgboostJNI.XGDMatrixSetGroup(handle, group));
|
||||
}
|
||||
|
||||
private float[] getFloatInfo(String field) {
|
||||
float[] infos = XgboostJNI.XGDMatrixGetFloatInfo(handle, field);
|
||||
return infos;
|
||||
private float[] getFloatInfo(String field) throws XgboostError {
|
||||
float[][] infos = new float[1][];
|
||||
ErrorHandle.checkCall(XgboostJNI.XGDMatrixGetFloatInfo(handle, field, infos));
|
||||
return infos[0];
|
||||
}
|
||||
|
||||
private int[] getIntInfo(String field) {
|
||||
int[] infos = XgboostJNI.XGDMatrixGetUIntInfo(handle, field);
|
||||
return infos;
|
||||
private int[] getIntInfo(String field) throws XgboostError {
|
||||
int[][] infos = new int[1][];
|
||||
ErrorHandle.checkCall(XgboostJNI.XGDMatrixGetUIntInfo(handle, field, infos));
|
||||
return infos[0];
|
||||
}
|
||||
|
||||
/**
|
||||
* get label values
|
||||
* @return label
|
||||
* @throws org.dmlc.xgboost4j.util.XgboostError
|
||||
*/
|
||||
public float[] getLabel() {
|
||||
public float[] getLabel() throws XgboostError {
|
||||
return getFloatInfo("label");
|
||||
}
|
||||
|
||||
/**
|
||||
* get weight of the DMatrix
|
||||
* @return weights
|
||||
* @throws org.dmlc.xgboost4j.util.XgboostError
|
||||
*/
|
||||
public float[] getWeight() {
|
||||
public float[] getWeight() throws XgboostError {
|
||||
return getFloatInfo("weight");
|
||||
}
|
||||
|
||||
/**
|
||||
* get base margin of the DMatrix
|
||||
* @return base margin
|
||||
* @throws org.dmlc.xgboost4j.util.XgboostError
|
||||
*/
|
||||
public float[] getBaseMargin() {
|
||||
public float[] getBaseMargin() throws XgboostError {
|
||||
return getFloatInfo("base_margin");
|
||||
}
|
||||
|
||||
@@ -175,9 +195,12 @@ public class DMatrix {
|
||||
* Slice the DMatrix and return a new DMatrix that only contains `rowIndex`.
|
||||
* @param rowIndex
|
||||
* @return sliced new DMatrix
|
||||
* @throws org.dmlc.xgboost4j.util.XgboostError
|
||||
*/
|
||||
public DMatrix slice(int[] rowIndex) {
|
||||
long sHandle = XgboostJNI.XGDMatrixSliceDMatrix(handle, rowIndex);
|
||||
public DMatrix slice(int[] rowIndex) throws XgboostError {
|
||||
long[] out = new long[1];
|
||||
ErrorHandle.checkCall(XgboostJNI.XGDMatrixSliceDMatrix(handle, rowIndex, out));
|
||||
long sHandle = out[0];
|
||||
DMatrix sMatrix = new DMatrix(sHandle);
|
||||
return sMatrix;
|
||||
}
|
||||
@@ -185,9 +208,12 @@ public class DMatrix {
|
||||
/**
|
||||
* get the row number of DMatrix
|
||||
* @return number of rows
|
||||
* @throws org.dmlc.xgboost4j.util.XgboostError
|
||||
*/
|
||||
public long rowNum() {
|
||||
return XgboostJNI.XGDMatrixNumRow(handle);
|
||||
public long rowNum() throws XgboostError {
|
||||
long[] rowNum = new long[1];
|
||||
ErrorHandle.checkCall(XgboostJNI.XGDMatrixNumRow(handle,rowNum));
|
||||
return rowNum[0];
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -37,8 +37,9 @@ public class CVPack {
|
||||
* @param dtrain train data
|
||||
* @param dtest test data
|
||||
* @param params parameters
|
||||
* @throws org.dmlc.xgboost4j.util.XgboostError
|
||||
*/
|
||||
public CVPack(DMatrix dtrain, DMatrix dtest, Iterable<Map.Entry<String, Object>> params) {
|
||||
public CVPack(DMatrix dtrain, DMatrix dtest, Iterable<Map.Entry<String, Object>> params) throws XgboostError {
|
||||
dmats = new DMatrix[] {dtrain, dtest};
|
||||
booster = new Booster(params, dmats);
|
||||
names = new String[] {"train", "test"};
|
||||
@@ -49,8 +50,9 @@ public class CVPack {
|
||||
/**
|
||||
* update one iteration
|
||||
* @param iter iteration num
|
||||
* @throws org.dmlc.xgboost4j.util.XgboostError
|
||||
*/
|
||||
public void update(int iter) {
|
||||
public void update(int iter) throws XgboostError {
|
||||
booster.update(dtrain, iter);
|
||||
}
|
||||
|
||||
@@ -58,8 +60,9 @@ public class CVPack {
|
||||
* update one iteration
|
||||
* @param iter iteration num
|
||||
* @param obj customized objective
|
||||
* @throws org.dmlc.xgboost4j.util.XgboostError
|
||||
*/
|
||||
public void update(int iter, IObjective obj) {
|
||||
public void update(int iter, IObjective obj) throws XgboostError {
|
||||
booster.update(dtrain, iter, obj);
|
||||
}
|
||||
|
||||
@@ -67,8 +70,9 @@ public class CVPack {
|
||||
* evaluation
|
||||
* @param iter iteration num
|
||||
* @return eval information
|
||||
* @throws org.dmlc.xgboost4j.util.XgboostError
|
||||
*/
|
||||
public String eval(int iter) {
|
||||
public String eval(int iter) throws XgboostError {
|
||||
return booster.evalSet(dmats, names, iter);
|
||||
}
|
||||
|
||||
@@ -77,8 +81,9 @@ public class CVPack {
|
||||
* @param iter iteration num
|
||||
* @param eval customized eval
|
||||
* @return eval information
|
||||
* @throws org.dmlc.xgboost4j.util.XgboostError
|
||||
*/
|
||||
public String eval(int iter, IEvaluation eval) {
|
||||
public String eval(int iter, IEvaluation eval) throws XgboostError {
|
||||
return booster.evalSet(dmats, names, iter, eval);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,50 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package org.dmlc.xgboost4j.util;
|
||||
|
||||
import java.io.IOException;
|
||||
import org.apache.commons.logging.Log;
|
||||
import org.apache.commons.logging.LogFactory;
|
||||
import org.dmlc.xgboost4j.wrapper.XgboostJNI;
|
||||
|
||||
/**
|
||||
* error handle for Xgboost
|
||||
* @author hzx
|
||||
*/
|
||||
public class ErrorHandle {
|
||||
private static final Log logger = LogFactory.getLog(ErrorHandle.class);
|
||||
|
||||
//load native library
|
||||
static {
|
||||
try {
|
||||
Initializer.InitXgboost();
|
||||
} catch (IOException ex) {
|
||||
logger.error("load native library failed.");
|
||||
logger.error(ex);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* check the return value of C API
|
||||
* @param ret return valud of xgboostJNI C API call
|
||||
* @throws org.dmlc.xgboost4j.util.XgboostError
|
||||
*/
|
||||
public static void checkCall(int ret) throws XgboostError {
|
||||
if(ret != 0) {
|
||||
throw new XgboostError(XgboostJNI.XGBGetLastError());
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -47,7 +47,7 @@ public class Trainer {
|
||||
* @return trained booster
|
||||
*/
|
||||
public static Booster train(Iterable<Entry<String, Object>> params, DMatrix dtrain, int round,
|
||||
Iterable<Entry<String, DMatrix>> watchs, IObjective obj, IEvaluation eval) {
|
||||
Iterable<Entry<String, DMatrix>> watchs, IObjective obj, IEvaluation eval) throws XgboostError {
|
||||
|
||||
//collect eval matrixs
|
||||
String[] evalNames;
|
||||
@@ -112,7 +112,7 @@ public class Trainer {
|
||||
* @param eval customized evaluation (set to null if not used)
|
||||
* @return evaluation history
|
||||
*/
|
||||
public static String[] crossValiation(Iterable<Entry<String, Object>> params, DMatrix data, int round, int nfold, String[] metrics, IObjective obj, IEvaluation eval) {
|
||||
public static String[] crossValiation(Iterable<Entry<String, Object>> params, DMatrix data, int round, int nfold, String[] metrics, IObjective obj, IEvaluation eval) throws XgboostError {
|
||||
CVPack[] cvPacks = makeNFold(data, nfold, params, metrics);
|
||||
String[] evalHist = new String[round];
|
||||
String[] results = new String[cvPacks.length];
|
||||
@@ -149,7 +149,7 @@ public class Trainer {
|
||||
* @param evalMetrics Evaluation metrics
|
||||
* @return CV package array
|
||||
*/
|
||||
public static CVPack[] makeNFold(DMatrix data, int nfold, Iterable<Entry<String, Object>> params, String[] evalMetrics) {
|
||||
public static CVPack[] makeNFold(DMatrix data, int nfold, Iterable<Entry<String, Object>> params, String[] evalMetrics) throws XgboostError {
|
||||
List<Integer> samples = genRandPermutationNums(0, (int) data.rowNum());
|
||||
int step = samples.size()/nfold;
|
||||
int[] testSlice = new int[step];
|
||||
|
||||
@@ -0,0 +1,26 @@
|
||||
/*
|
||||
Copyright (c) 2014 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
package org.dmlc.xgboost4j.util;
|
||||
|
||||
/**
 * Custom checked exception for xgboost errors.
 *
 * @author hzx
 */
public class XgboostError extends Exception {
    // Exception is Serializable; pin the version id so the serialized form
    // does not depend on compiler-generated defaults.
    private static final long serialVersionUID = 1L;

    /**
     * Create an error with a message only.
     *
     * @param message description of the failure
     */
    public XgboostError(String message) {
        super(message);
    }

    /**
     * Create an error that keeps the underlying cause.
     *
     * @param message description of the failure
     * @param cause the exception that triggered this error (may be {@code null})
     */
    public XgboostError(String message, Throwable cause) {
        super(message, cause);
    }
}
|
||||
@@ -17,32 +17,34 @@ package org.dmlc.xgboost4j.wrapper;
|
||||
|
||||
/**
|
||||
* xgboost jni wrapper functions for xgboost_wrapper.h
|
||||
* change 2015-7-6: use a long[] (length=1) as the container that receives the output DMatrix or Booster handle
|
||||
* @author hzx
|
||||
*/
|
||||
public class XgboostJNI {
|
||||
public final static native long XGDMatrixCreateFromFile(String fname, int silent);
|
||||
public final static native long XGDMatrixCreateFromCSR(long[] indptr, int[] indices, float[] data);
|
||||
public final static native long XGDMatrixCreateFromCSC(long[] colptr, int[] indices, float[] data);
|
||||
public final static native long XGDMatrixCreateFromMat(float[] data, int nrow, int ncol, float missing);
|
||||
public final static native long XGDMatrixSliceDMatrix(long handle, int[] idxset);
|
||||
public final static native void XGDMatrixFree(long handle);
|
||||
public final static native void XGDMatrixSaveBinary(long handle, String fname, int silent);
|
||||
public final static native void XGDMatrixSetFloatInfo(long handle, String field, float[] array);
|
||||
public final static native void XGDMatrixSetUIntInfo(long handle, String field, int[] array);
|
||||
public final static native void XGDMatrixSetGroup(long handle, int[] group);
|
||||
public final static native float[] XGDMatrixGetFloatInfo(long handle, String field);
|
||||
public final static native int[] XGDMatrixGetUIntInfo(long handle, String filed);
|
||||
public final static native long XGDMatrixNumRow(long handle);
|
||||
public final static native long XGBoosterCreate(long[] handles);
|
||||
public final static native void XGBoosterFree(long handle);
|
||||
public final static native void XGBoosterSetParam(long handle, String name, String value);
|
||||
public final static native void XGBoosterUpdateOneIter(long handle, int iter, long dtrain);
|
||||
public final static native void XGBoosterBoostOneIter(long handle, long dtrain, float[] grad, float[] hess);
|
||||
public final static native String XGBoosterEvalOneIter(long handle, int iter, long[] dmats, String[] evnames);
|
||||
public final static native float[] XGBoosterPredict(long handle, long dmat, int option_mask, long ntree_limit);
|
||||
public final static native void XGBoosterLoadModel(long handle, String fname);
|
||||
public final static native void XGBoosterSaveModel(long handle, String fname);
|
||||
public final static native void XGBoosterLoadModelFromBuffer(long handle, long buf, long len);
|
||||
public final static native String XGBoosterGetModelRaw(long handle);
|
||||
public final static native String[] XGBoosterDumpModel(long handle, String fmap, int with_stats);
|
||||
public final static native String XGBGetLastError();
|
||||
public final static native int XGDMatrixCreateFromFile(String fname, int silent, long[] out);
|
||||
public final static native int XGDMatrixCreateFromCSR(long[] indptr, int[] indices, float[] data, long[] out);
|
||||
public final static native int XGDMatrixCreateFromCSC(long[] colptr, int[] indices, float[] data, long[] out);
|
||||
public final static native int XGDMatrixCreateFromMat(float[] data, int nrow, int ncol, float missing, long[] out);
|
||||
public final static native int XGDMatrixSliceDMatrix(long handle, int[] idxset, long[] out);
|
||||
public final static native int XGDMatrixFree(long handle);
|
||||
public final static native int XGDMatrixSaveBinary(long handle, String fname, int silent);
|
||||
public final static native int XGDMatrixSetFloatInfo(long handle, String field, float[] array);
|
||||
public final static native int XGDMatrixSetUIntInfo(long handle, String field, int[] array);
|
||||
public final static native int XGDMatrixSetGroup(long handle, int[] group);
|
||||
public final static native int XGDMatrixGetFloatInfo(long handle, String field, float[][] info);
|
||||
public final static native int XGDMatrixGetUIntInfo(long handle, String filed, int[][] info);
|
||||
public final static native int XGDMatrixNumRow(long handle, long[] row);
|
||||
public final static native int XGBoosterCreate(long[] handles, long[] out);
|
||||
public final static native int XGBoosterFree(long handle);
|
||||
public final static native int XGBoosterSetParam(long handle, String name, String value);
|
||||
public final static native int XGBoosterUpdateOneIter(long handle, int iter, long dtrain);
|
||||
public final static native int XGBoosterBoostOneIter(long handle, long dtrain, float[] grad, float[] hess);
|
||||
public final static native int XGBoosterEvalOneIter(long handle, int iter, long[] dmats, String[] evnames, String[] eval_info);
|
||||
public final static native int XGBoosterPredict(long handle, long dmat, int option_mask, long ntree_limit, float[][] predicts);
|
||||
public final static native int XGBoosterLoadModel(long handle, String fname);
|
||||
public final static native int XGBoosterSaveModel(long handle, String fname);
|
||||
public final static native int XGBoosterLoadModelFromBuffer(long handle, long buf, long len);
|
||||
public final static native int XGBoosterGetModelRaw(long handle, String[] out_string);
|
||||
public final static native int XGBoosterDumpModel(long handle, String fmap, int with_stats, String[][] out_strings);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user