diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java index 672d538ea..92438121d 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java @@ -35,16 +35,6 @@ public class Booster implements Serializable, KryoSerializable { // handle to the booster. private long handle = 0; - //load native library - static { - try { - NativeLibLoader.initXGBoost(); - } catch (IOException ex) { - logger.error("load native library failed."); - logger.error(ex); - } - } - /** * Create a new Booster with empty stage. * @@ -70,7 +60,7 @@ public class Booster implements Serializable, KryoSerializable { throw new NullPointerException("modelPath : null"); } Booster ret = new Booster(new HashMap(), new DMatrix[0]); - JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterLoadModel(ret.handle, modelPath)); + XGBoostJNI.checkCall(XGBoostJNI.XGBoosterLoadModel(ret.handle, modelPath)); return ret; } @@ -93,7 +83,7 @@ public class Booster implements Serializable, KryoSerializable { } in.close(); Booster ret = new Booster(new HashMap(), new DMatrix[0]); - JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterLoadModelFromBuffer(ret.handle,os.toByteArray())); + XGBoostJNI.checkCall(XGBoostJNI.XGBoosterLoadModelFromBuffer(ret.handle,os.toByteArray())); return ret; } @@ -105,7 +95,7 @@ public class Booster implements Serializable, KryoSerializable { * @throws XGBoostError native error */ public final void setParam(String key, Object value) throws XGBoostError { - JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterSetParam(handle, key, value.toString())); + XGBoostJNI.checkCall(XGBoostJNI.XGBoosterSetParam(handle, key, value.toString())); } /** @@ -130,7 +120,7 @@ public class Booster implements Serializable, KryoSerializable { * @throws XGBoostError native error */ public void update(DMatrix dtrain, int iter) throws XGBoostError { - JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterUpdateOneIter(handle, iter, dtrain.getHandle())); + XGBoostJNI.checkCall(XGBoostJNI.XGBoosterUpdateOneIter(handle, iter, dtrain.getHandle())); } /** @@ -159,7 +149,7 @@ public class Booster implements Serializable, KryoSerializable { throw new AssertionError(String.format("grad/hess length mismatch %s / %s", grad.length, hess.length)); } - JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterBoostOneIter(handle, + XGBoostJNI.checkCall(XGBoostJNI.XGBoosterBoostOneIter(handle, dtrain.getHandle(), grad, hess)); } @@ -175,7 +165,7 @@ public class Booster implements Serializable, KryoSerializable { public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, int iter) throws XGBoostError { long[] handles = dmatrixsToHandles(evalMatrixs); String[] evalInfo = new String[1]; - JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterEvalOneIter(handle, iter, handles, evalNames, + XGBoostJNI.checkCall(XGBoostJNI.XGBoosterEvalOneIter(handle, iter, handles, evalNames, evalInfo)); return evalInfo[0]; } @@ -243,7 +233,7 @@ public class Booster implements Serializable, KryoSerializable { optionMask = 2; } float[][] rawPredicts = new float[1][]; - JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterPredict(handle, data.getHandle(), optionMask, + XGBoostJNI.checkCall(XGBoostJNI.XGBoosterPredict(handle, data.getHandle(), optionMask, treeLimit, rawPredicts)); int row = (int) data.rowNum(); int col = rawPredicts[0].length / row; @@ -309,7 +299,7 @@ public class Booster implements Serializable, KryoSerializable { * @param modelPath model path */ public void saveModel(String modelPath) throws XGBoostError{ - JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterSaveModel(handle, modelPath)); + XGBoostJNI.checkCall(XGBoostJNI.XGBoosterSaveModel(handle, modelPath)); } /** @@ -349,7 +339,7 @@ public class Booster implements Serializable, KryoSerializable { format = "text"; } String[][] modelInfos = new String[1][]; - JNIErrorHandle.checkCall( + XGBoostJNI.checkCall( XGBoostJNI.XGBoosterDumpModelEx(handle, featureMap, statsFlag, format, modelInfos)); return modelInfos[0]; } @@ -397,7 +387,7 @@ public class Booster implements Serializable, KryoSerializable { statsFlag = 1; } String[][] modelInfos = new String[1][]; - JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterDumpModelEx(handle, "", statsFlag, "text", + XGBoostJNI.checkCall(XGBoostJNI.XGBoosterDumpModelEx(handle, "", statsFlag, "text", modelInfos)); return modelInfos[0]; } @@ -409,7 +399,7 @@ public class Booster implements Serializable, KryoSerializable { */ public byte[] toByteArray() throws XGBoostError { byte[][] bytes = new byte[1][]; - JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterGetModelRaw(this.handle, bytes)); + XGBoostJNI.checkCall(XGBoostJNI.XGBoosterGetModelRaw(this.handle, bytes)); return bytes[0]; } @@ -421,7 +411,7 @@ public class Booster implements Serializable, KryoSerializable { */ int loadRabitCheckpoint() throws XGBoostError { int[] out = new int[1]; - JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterLoadRabitCheckpoint(this.handle, out)); + XGBoostJNI.checkCall(XGBoostJNI.XGBoosterLoadRabitCheckpoint(this.handle, out)); return out[0]; } @@ -431,7 +421,7 @@ public class Booster implements Serializable, KryoSerializable { * @throws XGBoostError */ void saveRabitCheckpoint() throws XGBoostError { - JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterSaveRabitCheckpoint(this.handle)); + XGBoostJNI.checkCall(XGBoostJNI.XGBoosterSaveRabitCheckpoint(this.handle)); } /** @@ -445,7 +435,7 @@ public class Booster implements Serializable, KryoSerializable { handles = dmatrixsToHandles(cacheMats); } long[] out = new long[1]; - JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterCreate(handles, out)); + XGBoostJNI.checkCall(XGBoostJNI.XGBoosterCreate(handles, out)); handle = out[0]; } @@ -479,7 +469,7 @@ public class Booster implements Serializable, KryoSerializable { try { this.init(null); byte[] bytes = (byte[])in.readObject(); - JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterLoadModelFromBuffer(this.handle, bytes)); + XGBoostJNI.checkCall(XGBoostJNI.XGBoosterLoadModelFromBuffer(this.handle, bytes)); } catch (XGBoostError ex) { ex.printStackTrace(); logger.error(ex.getMessage()); @@ -521,7 +511,7 @@ public class Booster implements Serializable, KryoSerializable { System.out.println("==== the size of the object: " + serObjSize); byte[] bytes = new byte[serObjSize]; input.readBytes(bytes); - JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterLoadModelFromBuffer(this.handle, bytes)); + XGBoostJNI.checkCall(XGBoostJNI.XGBoosterLoadModelFromBuffer(this.handle, bytes)); } catch (XGBoostError ex) { ex.printStackTrace(); logger.error(ex.getMessage()); diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DMatrix.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DMatrix.java index b2b55597c..e0fc1247c 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DMatrix.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DMatrix.java @@ -15,12 +15,8 @@ */ package ml.dmlc.xgboost4j.java; -import java.io.IOException; import java.util.Iterator; -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; - import ml.dmlc.xgboost4j.LabeledPoint; /** @@ -29,19 +25,8 @@ import ml.dmlc.xgboost4j.LabeledPoint; * @author hzx */ public class DMatrix { - private static final Log logger = LogFactory.getLog(DMatrix.class); protected long handle = 0; - //load native library - static { - try { - NativeLibLoader.initXGBoost(); - } catch (IOException ex) { - logger.error("load native library failed."); - logger.error(ex); - } - } - /** * sparse matrix type (CSR or CSC) */ @@ -65,7 +50,7 @@ public class DMatrix { int batchSize = 32 << 10; Iterator batchIter = new DataBatch.BatchIterator(iter, batchSize); long[] out = new long[1]; - JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromDataIter(batchIter, cacheInfo, out)); + XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixCreateFromDataIter(batchIter, cacheInfo, out)); handle = out[0]; } @@ -80,7 +65,7 @@ public class DMatrix { throw new NullPointerException("dataPath: null"); } long[] out = new long[1]; - JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromFile(dataPath, 1, out)); + XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixCreateFromFile(dataPath, 1, out)); handle = out[0]; } @@ -96,9 +81,9 @@ public class DMatrix { public DMatrix(long[] headers, int[] indices, float[] data, SparseType st) throws XGBoostError { long[] out = new long[1]; if (st == SparseType.CSR) { - JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromCSREx(headers, indices, data, 0, out)); + XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixCreateFromCSREx(headers, indices, data, 0, out)); } else if (st == SparseType.CSC) { - JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromCSCEx(headers, indices, data, 0, out)); + XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixCreateFromCSCEx(headers, indices, data, 0, out)); } else { throw new UnknownError("unknow sparsetype"); } @@ -119,10 +104,10 @@ public class DMatrix { throws XGBoostError { long[] out = new long[1]; if (st == SparseType.CSR) { - JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromCSREx(headers, indices, data, + XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixCreateFromCSREx(headers, indices, data, shapeParam, out)); } else if (st == SparseType.CSC) { - JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromCSCEx(headers, indices, data, + XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixCreateFromCSCEx(headers, indices, data, shapeParam, out)); } else { throw new UnknownError("unknow sparsetype"); @@ -140,7 +125,7 @@ public class DMatrix { */ public DMatrix(float[] data, int nrow, int ncol) throws XGBoostError { long[] out = new long[1]; - JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromMat(data, nrow, ncol, 0.0f, out)); + XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixCreateFromMat(data, nrow, ncol, 0.0f, out)); handle = out[0]; } @@ -153,7 +138,7 @@ public class DMatrix { */ public DMatrix(float[] data, int nrow, int ncol, float missing) throws XGBoostError { long[] out = new long[1]; - JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromMat(data, nrow, ncol, missing, out)); + XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixCreateFromMat(data, nrow, ncol, missing, out)); handle = out[0]; } @@ -172,7 +157,7 @@ public class DMatrix { * @throws XGBoostError native error */ public void setLabel(float[] labels) throws XGBoostError { - JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixSetFloatInfo(handle, "label", labels)); + XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixSetFloatInfo(handle, "label", labels)); } /** @@ -182,7 +167,7 @@ public class DMatrix { * @throws XGBoostError native error */ public void setWeight(float[] weights) throws XGBoostError { - JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixSetFloatInfo(handle, "weight", weights)); + XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixSetFloatInfo(handle, "weight", weights)); } /** @@ -193,7 +178,7 @@ public class DMatrix { * @throws XGBoostError native error */ public void setBaseMargin(float[] baseMargin) throws XGBoostError { - JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixSetFloatInfo(handle, "base_margin", baseMargin)); + XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixSetFloatInfo(handle, "base_margin", baseMargin)); } /** @@ -215,18 +200,18 @@ public class DMatrix { * @throws XGBoostError native error */ public void setGroup(int[] group) throws XGBoostError { - JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixSetGroup(handle, group)); + XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixSetGroup(handle, group)); } private float[] getFloatInfo(String field) throws XGBoostError { float[][] infos = new float[1][]; - JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixGetFloatInfo(handle, field, infos)); + XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixGetFloatInfo(handle, field, infos)); return infos[0]; } private int[] getIntInfo(String field) throws XGBoostError { int[][] infos = new int[1][]; - JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixGetUIntInfo(handle, field, infos)); + XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixGetUIntInfo(handle, field, infos)); return infos[0]; } @@ -269,7 +254,7 @@ public class DMatrix { */ public DMatrix slice(int[] rowIndex) throws XGBoostError { long[] out = new long[1]; - JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixSliceDMatrix(handle, rowIndex, out)); + XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixSliceDMatrix(handle, rowIndex, out)); long sHandle = out[0]; DMatrix sMatrix = new DMatrix(sHandle); return sMatrix; @@ -283,7 +268,7 @@ public class DMatrix { */ public long rowNum() throws XGBoostError { long[] rowNum = new long[1]; - JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixNumRow(handle, rowNum)); + XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixNumRow(handle, rowNum)); return rowNum[0]; } diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/JNIErrorHandle.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/JNIErrorHandle.java deleted file mode 100644 index c8888154b..000000000 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/JNIErrorHandle.java +++ /dev/null @@ -1,52 +0,0 @@ -/* - Copyright (c) 2014 by Contributors - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - */ -package ml.dmlc.xgboost4j.java; - -import java.io.IOException; - -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; - -/** - * Error handle for Xgboost. - */ -class JNIErrorHandle { - - private static final Log logger = LogFactory.getLog(DMatrix.class); - - //load native library - static { - try { - NativeLibLoader.initXGBoost(); - } catch (IOException ex) { - logger.error("load native library failed."); - logger.error(ex); - } - } - - /** - * Check the return value of C API. - * - * @param ret return valud of xgboostJNI C API call - * @throws XGBoostError native error - */ - static void checkCall(int ret) throws XGBoostError { - if (ret != 0) { - throw new XGBoostError(XGBoostJNI.XGBGetLastError()); - } - } - -} diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/NativeLibLoader.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/NativeLibLoader.java index 3fa05e0ed..1b5438cd5 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/NativeLibLoader.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/NativeLibLoader.java @@ -21,7 +21,6 @@ import java.lang.reflect.Field; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; - /** * class to load native library * @@ -35,7 +34,7 @@ class NativeLibLoader { private static final String nativeResourcePath = "/lib/"; private static final String[] libNames = new String[]{"xgboost4j"}; - public static synchronized void initXGBoost() throws IOException { + static synchronized void initXGBoost() throws IOException { if (!initialized) { for (String libName : libNames) { smartLoad(libName); diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Rabit.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Rabit.java index 6e996494b..710165d4c 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Rabit.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Rabit.java @@ -1,30 +1,14 @@ package ml.dmlc.xgboost4j.java; -import java.io.IOException; import java.io.Serializable; import java.nio.ByteBuffer; import java.nio.ByteOrder; -import java.util.Arrays; import java.util.Map; -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; - /** * Rabit global class for synchronization. */ public class Rabit { - private static final Log logger = LogFactory.getLog(DMatrix.class); - //load native library - static { - try { - NativeLibLoader.initXGBoost(); - } catch (IOException ex) { - logger.error("load native library failed."); - logger.error(ex); - } - } - public enum OpType implements Serializable { MAX(0), MIN(1), SUM(2), BITWISE_OR(3); diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java index 4d0f31dd1..6922f26b0 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java @@ -15,9 +15,11 @@ */ package ml.dmlc.xgboost4j.java; - import java.nio.ByteBuffer; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + /** * xgboost JNI functions * change 2015-7-6: *use a long[] (length=1) as container of handle to get the output DMatrix or Booster @@ -25,6 +27,28 @@ import java.nio.ByteBuffer; * @author hzx */ class XGBoostJNI { + private static final Log logger = LogFactory.getLog(DMatrix.class); + + static { + try { + NativeLibLoader.initXGBoost(); + } catch (Exception ex) { + logger.error("Failed to load native library", ex); + throw new RuntimeException(ex); + } + } + + /** + * Check the return code of the JNI call. + * + * @throws XGBoostError if the call failed. + */ + static void checkCall(int ret) throws XGBoostError { + if (ret != 0) { + throw new XGBoostError(XGBGetLastError()); + } + } + public final static native String XGBGetLastError(); public final static native int XGDMatrixCreateFromFile(String fname, int silent, long[] out); @@ -104,4 +128,4 @@ class XGBoostJNI { // This JNI function does not support the callback function for data preparation yet. final static native int RabitAllreduce(ByteBuffer sendrecvbuf, int count, int enum_dtype, int enum_op); -} +}