update java wrapper for new fault handling API

This commit is contained in:
yanqingmen
2015-07-06 02:32:58 -07:00
parent 7755c00721
commit f73bcd427d
19 changed files with 558 additions and 329 deletions

View File

@@ -30,6 +30,8 @@ import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.dmlc.xgboost4j.util.Initializer;
import org.dmlc.xgboost4j.util.ErrorHandle;
import org.dmlc.xgboost4j.util.XgboostError;
import org.dmlc.xgboost4j.wrapper.XgboostJNI;
@@ -57,8 +59,9 @@ public final class Booster {
* init Booster from dMatrixs
* @param params parameters
* @param dMatrixs DMatrix array
* @throws org.dmlc.xgboost4j.util.XgboostError
*/
public Booster(Iterable<Entry<String, Object>> params, DMatrix[] dMatrixs) {
public Booster(Iterable<Entry<String, Object>> params, DMatrix[] dMatrixs) throws XgboostError {
init(dMatrixs);
setParam("seed","0");
setParams(params);
@@ -70,9 +73,11 @@ public final class Booster {
* load model from modelPath
* @param params parameters
* @param modelPath booster modelPath (model generated by booster.saveModel)
* @throws org.dmlc.xgboost4j.util.XgboostError
*/
public Booster(Iterable<Entry<String, Object>> params, String modelPath) {
handle = XgboostJNI.XGBoosterCreate(new long[] {});
public Booster(Iterable<Entry<String, Object>> params, String modelPath) throws XgboostError {
long[] out = new long[1];
init(null);
loadModel(modelPath);
setParam("seed","0");
setParams(params);
@@ -81,28 +86,33 @@ public final class Booster {
private void init(DMatrix[] dMatrixs) {
private void init(DMatrix[] dMatrixs) throws XgboostError {
long[] handles = null;
if(dMatrixs != null) {
handles = dMatrixs2handles(dMatrixs);
}
handle = XgboostJNI.XGBoosterCreate(handles);
long[] out = new long[1];
ErrorHandle.checkCall(XgboostJNI.XGBoosterCreate(handles, out));
handle = out[0];
}
/**
* set parameter
* @param key param name
* @param value param value
* @throws org.dmlc.xgboost4j.util.XgboostError
*/
public final void setParam(String key, String value) {
XgboostJNI.XGBoosterSetParam(handle, key, value);
public final void setParam(String key, String value) throws XgboostError {
ErrorHandle.checkCall(XgboostJNI.XGBoosterSetParam(handle, key, value));
}
/**
* set parameters
* @param params parameters key-value map
* @throws org.dmlc.xgboost4j.util.XgboostError
*/
public void setParams(Iterable<Entry<String, Object>> params) {
public void setParams(Iterable<Entry<String, Object>> params) throws XgboostError {
if(params!=null) {
for(Map.Entry<String, Object> entry : params) {
setParam(entry.getKey(), entry.getValue().toString());
@@ -115,9 +125,10 @@ public final class Booster {
* Update (one iteration)
* @param dtrain training data
* @param iter current iteration number
* @throws org.dmlc.xgboost4j.util.XgboostError
*/
public void update(DMatrix dtrain, int iter) {
XgboostJNI.XGBoosterUpdateOneIter(handle, iter, dtrain.getHandle());
public void update(DMatrix dtrain, int iter) throws XgboostError {
ErrorHandle.checkCall(XgboostJNI.XGBoosterUpdateOneIter(handle, iter, dtrain.getHandle()));
}
/**
@@ -125,8 +136,9 @@ public final class Booster {
* @param dtrain training data
* @param iter current iteration number
* @param obj customized objective class
* @throws org.dmlc.xgboost4j.util.XgboostError
*/
public void update(DMatrix dtrain, int iter, IObjective obj) {
public void update(DMatrix dtrain, int iter, IObjective obj) throws XgboostError {
float[][] predicts = predict(dtrain, true);
List<float[]> gradients = obj.getGradient(predicts, dtrain);
boost(dtrain, gradients.get(0), gradients.get(1));
@@ -137,12 +149,13 @@ public final class Booster {
* @param dtrain training data
* @param grad first order of gradient
* @param hess second order of gradient
* @throws org.dmlc.xgboost4j.util.XgboostError
*/
public void boost(DMatrix dtrain, float[] grad, float[] hess) {
public void boost(DMatrix dtrain, float[] grad, float[] hess) throws XgboostError {
if(grad.length != hess.length) {
throw new AssertionError(String.format("grad/hess length mismatch %s / %s", grad.length, hess.length));
}
XgboostJNI.XGBoosterBoostOneIter(handle, dtrain.getHandle(), grad, hess);
ErrorHandle.checkCall(XgboostJNI.XGBoosterBoostOneIter(handle, dtrain.getHandle(), grad, hess));
}
/**
@@ -151,11 +164,13 @@ public final class Booster {
* @param evalNames name for eval dmatrixs, used for check results
* @param iter current eval iteration
* @return eval information
* @throws org.dmlc.xgboost4j.util.XgboostError
*/
public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, int iter) {
public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, int iter) throws XgboostError {
long[] handles = dMatrixs2handles(evalMatrixs);
String evalInfo = XgboostJNI.XGBoosterEvalOneIter(handle, iter, handles, evalNames);
return evalInfo;
String[] evalInfo = new String[1];
ErrorHandle.checkCall(XgboostJNI.XGBoosterEvalOneIter(handle, iter, handles, evalNames, evalInfo));
return evalInfo[0];
}
/**
@@ -165,8 +180,9 @@ public final class Booster {
* @param iter
* @param eval
* @return eval information
* @throws org.dmlc.xgboost4j.util.XgboostError
*/
public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, int iter, IEvaluation eval) {
public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, int iter, IEvaluation eval) throws XgboostError {
String evalInfo = "";
for(int i=0; i<evalNames.length; i++) {
String evalName = evalNames[i];
@@ -184,10 +200,12 @@ public final class Booster {
* @param evalNames name for eval dmatrixs, used for check results
* @param iter current eval iteration
* @return eval information
* @throws org.dmlc.xgboost4j.util.XgboostError
*/
public String evalSet(long[] dHandles, String[] evalNames, int iter) {
String evalInfo = XgboostJNI.XGBoosterEvalOneIter(handle, iter, dHandles, evalNames);
return evalInfo;
public String evalSet(long[] dHandles, String[] evalNames, int iter) throws XgboostError {
String[] evalInfo = new String[1];
ErrorHandle.checkCall(XgboostJNI.XGBoosterEvalOneIter(handle, iter, dHandles, evalNames, evalInfo));
return evalInfo[0];
}
@@ -197,8 +215,9 @@ public final class Booster {
* @param evalName
* @param iter
* @return eval information
* @throws org.dmlc.xgboost4j.util.XgboostError
*/
public String eval(DMatrix evalMat, String evalName, int iter) {
public String eval(DMatrix evalMat, String evalName, int iter) throws XgboostError {
DMatrix[] evalMats = new DMatrix[] {evalMat};
String[] evalNames = new String[] {evalName};
return evalSet(evalMats, evalNames, iter);
@@ -212,7 +231,7 @@ public final class Booster {
* @param predLeaf
* @return predict results
*/
private synchronized float[][] pred(DMatrix data, boolean outPutMargin, long treeLimit, boolean predLeaf) {
private synchronized float[][] pred(DMatrix data, boolean outPutMargin, long treeLimit, boolean predLeaf) throws XgboostError {
int optionMask = 0;
if(outPutMargin) {
optionMask = 1;
@@ -220,15 +239,16 @@ public final class Booster {
if(predLeaf) {
optionMask = 2;
}
float[] rawPredicts = XgboostJNI.XGBoosterPredict(handle, data.getHandle(), optionMask, treeLimit);
float[][] rawPredicts = new float[1][];
ErrorHandle.checkCall(XgboostJNI.XGBoosterPredict(handle, data.getHandle(), optionMask, treeLimit, rawPredicts));
int row = (int) data.rowNum();
int col = (int) rawPredicts.length/row;
int col = (int) rawPredicts[0].length/row;
float[][] predicts = new float[row][col];
int r,c;
for(int i=0; i< rawPredicts.length; i++) {
for(int i=0; i< rawPredicts[0].length; i++) {
r = i/col;
c = i%col;
predicts[r][c] = rawPredicts[i];
predicts[r][c] = rawPredicts[0][i];
}
return predicts;
}
@@ -237,8 +257,9 @@ public final class Booster {
* Predict with data
* @param data dmatrix storing the input
* @return predict result
* @throws org.dmlc.xgboost4j.util.XgboostError
*/
public float[][] predict(DMatrix data) {
public float[][] predict(DMatrix data) throws XgboostError {
return pred(data, false, 0, false);
}
@@ -247,8 +268,9 @@ public final class Booster {
* @param data dmatrix storing the input
* @param outPutMargin Whether to output the raw untransformed margin value.
* @return predict result
* @throws org.dmlc.xgboost4j.util.XgboostError
*/
public float[][] predict(DMatrix data, boolean outPutMargin) {
public float[][] predict(DMatrix data, boolean outPutMargin) throws XgboostError {
return pred(data, outPutMargin, 0, false);
}
@@ -258,8 +280,9 @@ public final class Booster {
* @param outPutMargin Whether to output the raw untransformed margin value.
* @param treeLimit Limit number of trees in the prediction; defaults to 0 (use all trees).
* @return predict result
* @throws org.dmlc.xgboost4j.util.XgboostError
*/
public float[][] predict(DMatrix data, boolean outPutMargin, long treeLimit) {
public float[][] predict(DMatrix data, boolean outPutMargin, long treeLimit) throws XgboostError {
return pred(data, outPutMargin, treeLimit, false);
}
@@ -272,8 +295,9 @@ public final class Booster {
Note that the leaf index of a tree is unique per tree, so you may find leaf 1
in both tree 1 and tree 0.
* @return predict result
* @throws org.dmlc.xgboost4j.util.XgboostError
*/
public float[][] predict(DMatrix data , long treeLimit, boolean predLeaf) {
public float[][] predict(DMatrix data , long treeLimit, boolean predLeaf) throws XgboostError {
return pred(data, false, treeLimit, predLeaf);
}
@@ -293,14 +317,16 @@ public final class Booster {
* get the dump of the model as a string array
* @param withStats Controls whether the split statistics are output.
* @return dumped model information
* @throws org.dmlc.xgboost4j.util.XgboostError
*/
public String[] getDumpInfo(boolean withStats) {
public String[] getDumpInfo(boolean withStats) throws XgboostError {
int statsFlag = 0;
if(withStats) {
statsFlag = 1;
}
String[] modelInfos = XgboostJNI.XGBoosterDumpModel(handle, "", statsFlag);
return modelInfos;
String[][] modelInfos = new String[1][];
ErrorHandle.checkCall(XgboostJNI.XGBoosterDumpModel(handle, "", statsFlag, modelInfos));
return modelInfos[0];
}
/**
@@ -308,14 +334,16 @@ public final class Booster {
* @param featureMap featureMap file
* @param withStats Controls whether the split statistics are output.
* @return dumped model information
* @throws org.dmlc.xgboost4j.util.XgboostError
*/
public String[] getDumpInfo(String featureMap, boolean withStats) {
public String[] getDumpInfo(String featureMap, boolean withStats) throws XgboostError {
int statsFlag = 0;
if(withStats) {
statsFlag = 1;
}
String[] modelInfos = XgboostJNI.XGBoosterDumpModel(handle, featureMap, statsFlag);
return modelInfos;
String[][] modelInfos = new String[1][];
ErrorHandle.checkCall(XgboostJNI.XGBoosterDumpModel(handle, featureMap, statsFlag, modelInfos));
return modelInfos[0];
}
/**
@@ -326,8 +354,9 @@ public final class Booster {
* @throws FileNotFoundException
* @throws UnsupportedEncodingException
* @throws IOException
* @throws org.dmlc.xgboost4j.util.XgboostError
*/
public void dumpModel(String modelPath, boolean withStats) throws FileNotFoundException, UnsupportedEncodingException, IOException {
public void dumpModel(String modelPath, boolean withStats) throws FileNotFoundException, UnsupportedEncodingException, IOException, XgboostError {
File tf = new File(modelPath);
FileOutputStream out = new FileOutputStream(tf);
BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(out, "UTF-8"));
@@ -352,8 +381,9 @@ public final class Booster {
* @throws FileNotFoundException
* @throws UnsupportedEncodingException
* @throws IOException
* @throws org.dmlc.xgboost4j.util.XgboostError
*/
public void dumpModel(String modelPath, String featureMap, boolean withStats) throws FileNotFoundException, UnsupportedEncodingException, IOException {
public void dumpModel(String modelPath, String featureMap, boolean withStats) throws FileNotFoundException, UnsupportedEncodingException, IOException, XgboostError {
File tf = new File(modelPath);
FileOutputStream out = new FileOutputStream(tf);
BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(out, "UTF-8"));
@@ -372,8 +402,9 @@ public final class Booster {
/**
* get importance of each feature
* @return featureMap key: feature index, value: feature importance score
* @throws org.dmlc.xgboost4j.util.XgboostError
*/
public Map<String, Integer> getFeatureScore() {
public Map<String, Integer> getFeatureScore() throws XgboostError {
String[] modelInfos = getDumpInfo(false);
Map<String, Integer> featureScore = new HashMap<>();
for(String tree : modelInfos) {
@@ -400,8 +431,9 @@ public final class Booster {
* get importance of each feature
* @param featureMap file to save dumped model info
* @return featureMap key: feature index, value: feature importance score
* @throws org.dmlc.xgboost4j.util.XgboostError
*/
public Map<String, Integer> getFeatureScore(String featureMap) {
public Map<String, Integer> getFeatureScore(String featureMap) throws XgboostError {
String[] modelInfos = getDumpInfo(featureMap, false);
Map<String, Integer> featureScore = new HashMap<>();
for(String tree : modelInfos) {

View File

@@ -18,6 +18,8 @@ package org.dmlc.xgboost4j;
import java.io.IOException;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.dmlc.xgboost4j.util.ErrorHandle;
import org.dmlc.xgboost4j.util.XgboostError;
import org.dmlc.xgboost4j.util.Initializer;
import org.dmlc.xgboost4j.wrapper.XgboostJNI;
@@ -50,9 +52,12 @@ public class DMatrix {
/**
* init DMatrix from file (svmlight format)
* @param dataPath
* @throws org.dmlc.xgboost4j.util.XgboostError
*/
public DMatrix(String dataPath) {
handle = XgboostJNI.XGDMatrixCreateFromFile(dataPath, 1);
public DMatrix(String dataPath) throws XgboostError {
long[] out = new long[1];
ErrorHandle.checkCall(XgboostJNI.XGDMatrixCreateFromFile(dataPath, 1, out));
handle = out[0];
}
/**
@@ -61,17 +66,20 @@ public class DMatrix {
* @param indices Indices (colIndexs for CSR or rowIndexs for CSC)
* @param data non zero values (sequence by row for CSR or by col for CSC)
* @param st sparse matrix type (CSR or CSC)
* @throws org.dmlc.xgboost4j.util.XgboostError
*/
public DMatrix(long[] headers, int[] indices, float[] data, SparseType st) {
public DMatrix(long[] headers, int[] indices, float[] data, SparseType st) throws XgboostError {
long[] out = new long[1];
if(st == SparseType.CSR) {
handle = XgboostJNI.XGDMatrixCreateFromCSR(headers, indices, data);
ErrorHandle.checkCall(XgboostJNI.XGDMatrixCreateFromCSR(headers, indices, data, out));
}
else if(st == SparseType.CSC) {
handle = XgboostJNI.XGDMatrixCreateFromCSC(headers, indices, data);
ErrorHandle.checkCall(XgboostJNI.XGDMatrixCreateFromCSC(headers, indices, data, out));
}
else {
throw new UnknownError("unknow sparsetype");
}
handle = out[0];
}
/**
@@ -79,10 +87,13 @@ public class DMatrix {
* @param data data values
* @param nrow number of rows
* @param ncol number of columns
* @throws org.dmlc.xgboost4j.util.XgboostError
*/
public DMatrix(float[] data, int nrow, int ncol) {
handle = XgboostJNI.XGDMatrixCreateFromMat(data, nrow, ncol, 0.0f);
}
public DMatrix(float[] data, int nrow, int ncol) throws XgboostError {
long[] out = new long[1];
ErrorHandle.checkCall(XgboostJNI.XGDMatrixCreateFromMat(data, nrow, ncol, 0.0f, out));
handle = out[0];
}
/**
* used for DMatrix slice
@@ -98,33 +109,36 @@ public class DMatrix {
* set label of dmatrix
* @param labels
*/
public void setLabel(float[] labels) {
XgboostJNI.XGDMatrixSetFloatInfo(handle, "label", labels);
public void setLabel(float[] labels) throws XgboostError {
ErrorHandle.checkCall(XgboostJNI.XGDMatrixSetFloatInfo(handle, "label", labels));
}
/**
* set weight of each instance
* @param weights
* @throws org.dmlc.xgboost4j.util.XgboostError
*/
public void setWeight(float[] weights) {
XgboostJNI.XGDMatrixSetFloatInfo(handle, "weight", weights);
public void setWeight(float[] weights) throws XgboostError {
ErrorHandle.checkCall(XgboostJNI.XGDMatrixSetFloatInfo(handle, "weight", weights));
}
/**
* if specified, xgboost will start from this init margin
* can be used to specify initial prediction to boost from
* @param baseMargin
* @throws org.dmlc.xgboost4j.util.XgboostError
*/
public void setBaseMargin(float[] baseMargin) {
XgboostJNI.XGDMatrixSetFloatInfo(handle, "base_margin", baseMargin);
public void setBaseMargin(float[] baseMargin) throws XgboostError {
ErrorHandle.checkCall(XgboostJNI.XGDMatrixSetFloatInfo(handle, "base_margin", baseMargin));
}
/**
* if specified, xgboost will start from this init margin
* can be used to specify initial prediction to boost from
* @param baseMargin
* @throws org.dmlc.xgboost4j.util.XgboostError
*/
public void setBaseMargin(float[][] baseMargin) {
public void setBaseMargin(float[][] baseMargin) throws XgboostError {
float[] flattenMargin = flatten(baseMargin);
setBaseMargin(flattenMargin);
}
@@ -132,42 +146,48 @@ public class DMatrix {
/**
* Set group sizes of DMatrix (used for ranking)
* @param group
* @throws org.dmlc.xgboost4j.util.XgboostError
*/
public void setGroup(int[] group) {
XgboostJNI.XGDMatrixSetGroup(handle, group);
public void setGroup(int[] group) throws XgboostError {
ErrorHandle.checkCall(XgboostJNI.XGDMatrixSetGroup(handle, group));
}
private float[] getFloatInfo(String field) {
float[] infos = XgboostJNI.XGDMatrixGetFloatInfo(handle, field);
return infos;
private float[] getFloatInfo(String field) throws XgboostError {
float[][] infos = new float[1][];
ErrorHandle.checkCall(XgboostJNI.XGDMatrixGetFloatInfo(handle, field, infos));
return infos[0];
}
private int[] getIntInfo(String field) {
int[] infos = XgboostJNI.XGDMatrixGetUIntInfo(handle, field);
return infos;
private int[] getIntInfo(String field) throws XgboostError {
int[][] infos = new int[1][];
ErrorHandle.checkCall(XgboostJNI.XGDMatrixGetUIntInfo(handle, field, infos));
return infos[0];
}
/**
* get label values
* @return label
* @throws org.dmlc.xgboost4j.util.XgboostError
*/
public float[] getLabel() {
public float[] getLabel() throws XgboostError {
return getFloatInfo("label");
}
/**
* get weight of the DMatrix
* @return weights
* @throws org.dmlc.xgboost4j.util.XgboostError
*/
public float[] getWeight() {
public float[] getWeight() throws XgboostError {
return getFloatInfo("weight");
}
/**
* get base margin of the DMatrix
* @return base margin
* @throws org.dmlc.xgboost4j.util.XgboostError
*/
public float[] getBaseMargin() {
public float[] getBaseMargin() throws XgboostError {
return getFloatInfo("base_margin");
}
@@ -175,9 +195,12 @@ public class DMatrix {
* Slice the DMatrix and return a new DMatrix that only contains `rowIndex`.
* @param rowIndex
* @return sliced new DMatrix
* @throws org.dmlc.xgboost4j.util.XgboostError
*/
public DMatrix slice(int[] rowIndex) {
long sHandle = XgboostJNI.XGDMatrixSliceDMatrix(handle, rowIndex);
public DMatrix slice(int[] rowIndex) throws XgboostError {
long[] out = new long[1];
ErrorHandle.checkCall(XgboostJNI.XGDMatrixSliceDMatrix(handle, rowIndex, out));
long sHandle = out[0];
DMatrix sMatrix = new DMatrix(sHandle);
return sMatrix;
}
@@ -185,9 +208,12 @@ public class DMatrix {
/**
* get the row number of DMatrix
* @return number of rows
* @throws org.dmlc.xgboost4j.util.XgboostError
*/
public long rowNum() {
return XgboostJNI.XGDMatrixNumRow(handle);
public long rowNum() throws XgboostError {
long[] rowNum = new long[1];
ErrorHandle.checkCall(XgboostJNI.XGDMatrixNumRow(handle,rowNum));
return rowNum[0];
}
/**

View File

@@ -37,8 +37,9 @@ public class CVPack {
* @param dtrain train data
* @param dtest test data
* @param params parameters
* @throws org.dmlc.xgboost4j.util.XgboostError
*/
public CVPack(DMatrix dtrain, DMatrix dtest, Iterable<Map.Entry<String, Object>> params) {
public CVPack(DMatrix dtrain, DMatrix dtest, Iterable<Map.Entry<String, Object>> params) throws XgboostError {
dmats = new DMatrix[] {dtrain, dtest};
booster = new Booster(params, dmats);
names = new String[] {"train", "test"};
@@ -49,8 +50,9 @@ public class CVPack {
/**
* update one iteration
* @param iter iteration num
* @throws org.dmlc.xgboost4j.util.XgboostError
*/
public void update(int iter) {
public void update(int iter) throws XgboostError {
booster.update(dtrain, iter);
}
@@ -58,8 +60,9 @@ public class CVPack {
* update one iteration
* @param iter iteration num
* @param obj customized objective
* @throws org.dmlc.xgboost4j.util.XgboostError
*/
public void update(int iter, IObjective obj) {
public void update(int iter, IObjective obj) throws XgboostError {
booster.update(dtrain, iter, obj);
}
@@ -67,8 +70,9 @@ public class CVPack {
* evaluation
* @param iter iteration num
* @return
* @throws org.dmlc.xgboost4j.util.XgboostError
*/
public String eval(int iter) {
public String eval(int iter) throws XgboostError {
return booster.evalSet(dmats, names, iter);
}
@@ -77,8 +81,9 @@ public class CVPack {
* @param iter iteration num
* @param eval customized eval
* @return
* @throws org.dmlc.xgboost4j.util.XgboostError
*/
public String eval(int iter, IEvaluation eval) {
public String eval(int iter, IEvaluation eval) throws XgboostError {
return booster.evalSet(dmats, names, iter, eval);
}
}

View File

@@ -0,0 +1,50 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j.util;
import java.io.IOException;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.dmlc.xgboost4j.wrapper.XgboostJNI;
/**
 * Helper for translating native xgboost C API return codes into Java exceptions.
 * Loading this class also triggers loading of the native xgboost library via
 * {@link Initializer}.
 *
 * @author hzx
 */
public class ErrorHandle {
    private static final Log logger = LogFactory.getLog(ErrorHandle.class);

    // Load the native xgboost library once, when this class is first used.
    // A load failure is logged rather than rethrown; any later JNI call will
    // then fail on its own.
    static {
        try {
            Initializer.InitXgboost();
        } catch (IOException ex) {
            logger.error("load native library failed.");
            logger.error(ex);
        }
    }

    /**
     * Check the return value of a C API call and raise a Java exception on failure.
     *
     * @param ret return value of an XgboostJNI C API call; non-zero indicates an error
     * @throws org.dmlc.xgboost4j.util.XgboostError if the native call reported failure,
     *         carrying the message from XGBGetLastError()
     */
    public static void checkCall(int ret) throws XgboostError {
        if(ret != 0) {
            throw new XgboostError(XgboostJNI.XGBGetLastError());
        }
    }
}

View File

@@ -47,7 +47,7 @@ public class Trainer {
* @return trained booster
*/
public static Booster train(Iterable<Entry<String, Object>> params, DMatrix dtrain, int round,
Iterable<Entry<String, DMatrix>> watchs, IObjective obj, IEvaluation eval) {
Iterable<Entry<String, DMatrix>> watchs, IObjective obj, IEvaluation eval) throws XgboostError {
//collect eval matrixs
String[] evalNames;
@@ -112,7 +112,7 @@ public class Trainer {
* @param eval customized evaluation (set to null if not used)
* @return evaluation history
*/
public static String[] crossValiation(Iterable<Entry<String, Object>> params, DMatrix data, int round, int nfold, String[] metrics, IObjective obj, IEvaluation eval) {
public static String[] crossValiation(Iterable<Entry<String, Object>> params, DMatrix data, int round, int nfold, String[] metrics, IObjective obj, IEvaluation eval) throws XgboostError {
CVPack[] cvPacks = makeNFold(data, nfold, params, metrics);
String[] evalHist = new String[round];
String[] results = new String[cvPacks.length];
@@ -149,7 +149,7 @@ public class Trainer {
* @param evalMetrics Evaluation metrics
* @return CV package array
*/
public static CVPack[] makeNFold(DMatrix data, int nfold, Iterable<Entry<String, Object>> params, String[] evalMetrics) {
public static CVPack[] makeNFold(DMatrix data, int nfold, Iterable<Entry<String, Object>> params, String[] evalMetrics) throws XgboostError {
List<Integer> samples = genRandPermutationNums(0, (int) data.rowNum());
int step = samples.size()/nfold;
int[] testSlice = new int[step];

View File

@@ -0,0 +1,26 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package org.dmlc.xgboost4j.util;
/**
 * Checked exception raised when a native xgboost C API call reports failure.
 *
 * @author hzx
 */
public class XgboostError extends Exception {
    /**
     * Create an error carrying the message reported by the native layer.
     *
     * @param msg error description, typically obtained from XGBGetLastError()
     */
    public XgboostError(String msg) {
        super(msg);
    }
}

View File

@@ -17,32 +17,34 @@ package org.dmlc.xgboost4j.wrapper;
/**
* xgboost jni wrapper functions for xgboost_wrapper.h
* change 2015-7-6: use a long[] (length=1) as a container to receive the output DMatrix or Booster handle
* @author hzx
*/
public class XgboostJNI {
public final static native long XGDMatrixCreateFromFile(String fname, int silent);
public final static native long XGDMatrixCreateFromCSR(long[] indptr, int[] indices, float[] data);
public final static native long XGDMatrixCreateFromCSC(long[] colptr, int[] indices, float[] data);
public final static native long XGDMatrixCreateFromMat(float[] data, int nrow, int ncol, float missing);
public final static native long XGDMatrixSliceDMatrix(long handle, int[] idxset);
public final static native void XGDMatrixFree(long handle);
public final static native void XGDMatrixSaveBinary(long handle, String fname, int silent);
public final static native void XGDMatrixSetFloatInfo(long handle, String field, float[] array);
public final static native void XGDMatrixSetUIntInfo(long handle, String field, int[] array);
public final static native void XGDMatrixSetGroup(long handle, int[] group);
public final static native float[] XGDMatrixGetFloatInfo(long handle, String field);
public final static native int[] XGDMatrixGetUIntInfo(long handle, String filed);
public final static native long XGDMatrixNumRow(long handle);
public final static native long XGBoosterCreate(long[] handles);
public final static native void XGBoosterFree(long handle);
public final static native void XGBoosterSetParam(long handle, String name, String value);
public final static native void XGBoosterUpdateOneIter(long handle, int iter, long dtrain);
public final static native void XGBoosterBoostOneIter(long handle, long dtrain, float[] grad, float[] hess);
public final static native String XGBoosterEvalOneIter(long handle, int iter, long[] dmats, String[] evnames);
public final static native float[] XGBoosterPredict(long handle, long dmat, int option_mask, long ntree_limit);
public final static native void XGBoosterLoadModel(long handle, String fname);
public final static native void XGBoosterSaveModel(long handle, String fname);
public final static native void XGBoosterLoadModelFromBuffer(long handle, long buf, long len);
public final static native String XGBoosterGetModelRaw(long handle);
public final static native String[] XGBoosterDumpModel(long handle, String fmap, int with_stats);
public final static native String XGBGetLastError();
public final static native int XGDMatrixCreateFromFile(String fname, int silent, long[] out);
public final static native int XGDMatrixCreateFromCSR(long[] indptr, int[] indices, float[] data, long[] out);
public final static native int XGDMatrixCreateFromCSC(long[] colptr, int[] indices, float[] data, long[] out);
public final static native int XGDMatrixCreateFromMat(float[] data, int nrow, int ncol, float missing, long[] out);
public final static native int XGDMatrixSliceDMatrix(long handle, int[] idxset, long[] out);
public final static native int XGDMatrixFree(long handle);
public final static native int XGDMatrixSaveBinary(long handle, String fname, int silent);
public final static native int XGDMatrixSetFloatInfo(long handle, String field, float[] array);
public final static native int XGDMatrixSetUIntInfo(long handle, String field, int[] array);
public final static native int XGDMatrixSetGroup(long handle, int[] group);
public final static native int XGDMatrixGetFloatInfo(long handle, String field, float[][] info);
public final static native int XGDMatrixGetUIntInfo(long handle, String filed, int[][] info);
public final static native int XGDMatrixNumRow(long handle, long[] row);
public final static native int XGBoosterCreate(long[] handles, long[] out);
public final static native int XGBoosterFree(long handle);
public final static native int XGBoosterSetParam(long handle, String name, String value);
public final static native int XGBoosterUpdateOneIter(long handle, int iter, long dtrain);
public final static native int XGBoosterBoostOneIter(long handle, long dtrain, float[] grad, float[] hess);
public final static native int XGBoosterEvalOneIter(long handle, int iter, long[] dmats, String[] evnames, String[] eval_info);
public final static native int XGBoosterPredict(long handle, long dmat, int option_mask, long ntree_limit, float[][] predicts);
public final static native int XGBoosterLoadModel(long handle, String fname);
public final static native int XGBoosterSaveModel(long handle, String fname);
public final static native int XGBoosterLoadModelFromBuffer(long handle, long buf, long len);
public final static native int XGBoosterGetModelRaw(long handle, String[] out_string);
public final static native int XGBoosterDumpModel(long handle, String fmap, int with_stats, String[][] out_strings);
}