[jvm-packages] JNI Cosmetics (#2448)

* [jvm-packages] Ensure the native library is loaded once

Previously any class using XGBoostJNI queried NativeLibLoader to make
sure the native library is loaded. This commit moves the initXGBoost
call to XGBoostJNI, effectively delegating the initialization to the class
loader.

Note also, that now XGBoostJNI would NOT suppress an IOException if it
occured in initXGBoost.

* [jvm-packages] Fused JNIErrorHandle with XGBoostJNI

There was no reason for having a separate class.
This commit is contained in:
Sergei Lebedev 2017-06-23 20:49:30 +02:00 committed by Nan Zhu
parent 0e48f87529
commit 91e778c6db
6 changed files with 59 additions and 129 deletions

View File

@ -35,16 +35,6 @@ public class Booster implements Serializable, KryoSerializable {
// handle to the booster.
private long handle = 0;
//load native library
static {
try {
NativeLibLoader.initXGBoost();
} catch (IOException ex) {
logger.error("load native library failed.");
logger.error(ex);
}
}
/**
* Create a new Booster with empty stage.
*
@ -70,7 +60,7 @@ public class Booster implements Serializable, KryoSerializable {
throw new NullPointerException("modelPath : null");
}
Booster ret = new Booster(new HashMap<String, Object>(), new DMatrix[0]);
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterLoadModel(ret.handle, modelPath));
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterLoadModel(ret.handle, modelPath));
return ret;
}
@ -93,7 +83,7 @@ public class Booster implements Serializable, KryoSerializable {
}
in.close();
Booster ret = new Booster(new HashMap<String, Object>(), new DMatrix[0]);
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterLoadModelFromBuffer(ret.handle,os.toByteArray()));
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterLoadModelFromBuffer(ret.handle,os.toByteArray()));
return ret;
}
@ -105,7 +95,7 @@ public class Booster implements Serializable, KryoSerializable {
* @throws XGBoostError native error
*/
public final void setParam(String key, Object value) throws XGBoostError {
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterSetParam(handle, key, value.toString()));
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterSetParam(handle, key, value.toString()));
}
/**
@ -130,7 +120,7 @@ public class Booster implements Serializable, KryoSerializable {
* @throws XGBoostError native error
*/
public void update(DMatrix dtrain, int iter) throws XGBoostError {
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterUpdateOneIter(handle, iter, dtrain.getHandle()));
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterUpdateOneIter(handle, iter, dtrain.getHandle()));
}
/**
@ -159,7 +149,7 @@ public class Booster implements Serializable, KryoSerializable {
throw new AssertionError(String.format("grad/hess length mismatch %s / %s", grad.length,
hess.length));
}
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterBoostOneIter(handle,
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterBoostOneIter(handle,
dtrain.getHandle(), grad, hess));
}
@ -175,7 +165,7 @@ public class Booster implements Serializable, KryoSerializable {
public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, int iter) throws XGBoostError {
long[] handles = dmatrixsToHandles(evalMatrixs);
String[] evalInfo = new String[1];
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterEvalOneIter(handle, iter, handles, evalNames,
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterEvalOneIter(handle, iter, handles, evalNames,
evalInfo));
return evalInfo[0];
}
@ -243,7 +233,7 @@ public class Booster implements Serializable, KryoSerializable {
optionMask = 2;
}
float[][] rawPredicts = new float[1][];
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterPredict(handle, data.getHandle(), optionMask,
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterPredict(handle, data.getHandle(), optionMask,
treeLimit, rawPredicts));
int row = (int) data.rowNum();
int col = rawPredicts[0].length / row;
@ -309,7 +299,7 @@ public class Booster implements Serializable, KryoSerializable {
* @param modelPath model path
*/
public void saveModel(String modelPath) throws XGBoostError{
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterSaveModel(handle, modelPath));
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterSaveModel(handle, modelPath));
}
/**
@ -349,7 +339,7 @@ public class Booster implements Serializable, KryoSerializable {
format = "text";
}
String[][] modelInfos = new String[1][];
JNIErrorHandle.checkCall(
XGBoostJNI.checkCall(
XGBoostJNI.XGBoosterDumpModelEx(handle, featureMap, statsFlag, format, modelInfos));
return modelInfos[0];
}
@ -397,7 +387,7 @@ public class Booster implements Serializable, KryoSerializable {
statsFlag = 1;
}
String[][] modelInfos = new String[1][];
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterDumpModelEx(handle, "", statsFlag, "text",
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterDumpModelEx(handle, "", statsFlag, "text",
modelInfos));
return modelInfos[0];
}
@ -409,7 +399,7 @@ public class Booster implements Serializable, KryoSerializable {
*/
public byte[] toByteArray() throws XGBoostError {
byte[][] bytes = new byte[1][];
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterGetModelRaw(this.handle, bytes));
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterGetModelRaw(this.handle, bytes));
return bytes[0];
}
@ -421,7 +411,7 @@ public class Booster implements Serializable, KryoSerializable {
*/
int loadRabitCheckpoint() throws XGBoostError {
int[] out = new int[1];
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterLoadRabitCheckpoint(this.handle, out));
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterLoadRabitCheckpoint(this.handle, out));
return out[0];
}
@ -431,7 +421,7 @@ public class Booster implements Serializable, KryoSerializable {
* @throws XGBoostError
*/
void saveRabitCheckpoint() throws XGBoostError {
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterSaveRabitCheckpoint(this.handle));
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterSaveRabitCheckpoint(this.handle));
}
/**
@ -445,7 +435,7 @@ public class Booster implements Serializable, KryoSerializable {
handles = dmatrixsToHandles(cacheMats);
}
long[] out = new long[1];
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterCreate(handles, out));
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterCreate(handles, out));
handle = out[0];
}
@ -479,7 +469,7 @@ public class Booster implements Serializable, KryoSerializable {
try {
this.init(null);
byte[] bytes = (byte[])in.readObject();
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterLoadModelFromBuffer(this.handle, bytes));
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterLoadModelFromBuffer(this.handle, bytes));
} catch (XGBoostError ex) {
ex.printStackTrace();
logger.error(ex.getMessage());
@ -521,7 +511,7 @@ public class Booster implements Serializable, KryoSerializable {
System.out.println("==== the size of the object: " + serObjSize);
byte[] bytes = new byte[serObjSize];
input.readBytes(bytes);
JNIErrorHandle.checkCall(XGBoostJNI.XGBoosterLoadModelFromBuffer(this.handle, bytes));
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterLoadModelFromBuffer(this.handle, bytes));
} catch (XGBoostError ex) {
ex.printStackTrace();
logger.error(ex.getMessage());

View File

@ -15,12 +15,8 @@
*/
package ml.dmlc.xgboost4j.java;
import java.io.IOException;
import java.util.Iterator;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import ml.dmlc.xgboost4j.LabeledPoint;
/**
@ -29,19 +25,8 @@ import ml.dmlc.xgboost4j.LabeledPoint;
* @author hzx
*/
public class DMatrix {
private static final Log logger = LogFactory.getLog(DMatrix.class);
protected long handle = 0;
//load native library
static {
try {
NativeLibLoader.initXGBoost();
} catch (IOException ex) {
logger.error("load native library failed.");
logger.error(ex);
}
}
/**
* sparse matrix type (CSR or CSC)
*/
@ -65,7 +50,7 @@ public class DMatrix {
int batchSize = 32 << 10;
Iterator<DataBatch> batchIter = new DataBatch.BatchIterator(iter, batchSize);
long[] out = new long[1];
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromDataIter(batchIter, cacheInfo, out));
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixCreateFromDataIter(batchIter, cacheInfo, out));
handle = out[0];
}
@ -80,7 +65,7 @@ public class DMatrix {
throw new NullPointerException("dataPath: null");
}
long[] out = new long[1];
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromFile(dataPath, 1, out));
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixCreateFromFile(dataPath, 1, out));
handle = out[0];
}
@ -96,9 +81,9 @@ public class DMatrix {
public DMatrix(long[] headers, int[] indices, float[] data, SparseType st) throws XGBoostError {
long[] out = new long[1];
if (st == SparseType.CSR) {
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromCSREx(headers, indices, data, 0, out));
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixCreateFromCSREx(headers, indices, data, 0, out));
} else if (st == SparseType.CSC) {
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromCSCEx(headers, indices, data, 0, out));
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixCreateFromCSCEx(headers, indices, data, 0, out));
} else {
throw new UnknownError("unknow sparsetype");
}
@ -119,10 +104,10 @@ public class DMatrix {
throws XGBoostError {
long[] out = new long[1];
if (st == SparseType.CSR) {
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromCSREx(headers, indices, data,
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixCreateFromCSREx(headers, indices, data,
shapeParam, out));
} else if (st == SparseType.CSC) {
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromCSCEx(headers, indices, data,
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixCreateFromCSCEx(headers, indices, data,
shapeParam, out));
} else {
throw new UnknownError("unknow sparsetype");
@ -140,7 +125,7 @@ public class DMatrix {
*/
public DMatrix(float[] data, int nrow, int ncol) throws XGBoostError {
long[] out = new long[1];
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromMat(data, nrow, ncol, 0.0f, out));
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixCreateFromMat(data, nrow, ncol, 0.0f, out));
handle = out[0];
}
@ -153,7 +138,7 @@ public class DMatrix {
*/
public DMatrix(float[] data, int nrow, int ncol, float missing) throws XGBoostError {
long[] out = new long[1];
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromMat(data, nrow, ncol, missing, out));
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixCreateFromMat(data, nrow, ncol, missing, out));
handle = out[0];
}
@ -172,7 +157,7 @@ public class DMatrix {
* @throws XGBoostError native error
*/
public void setLabel(float[] labels) throws XGBoostError {
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixSetFloatInfo(handle, "label", labels));
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixSetFloatInfo(handle, "label", labels));
}
/**
@ -182,7 +167,7 @@ public class DMatrix {
* @throws XGBoostError native error
*/
public void setWeight(float[] weights) throws XGBoostError {
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixSetFloatInfo(handle, "weight", weights));
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixSetFloatInfo(handle, "weight", weights));
}
/**
@ -193,7 +178,7 @@ public class DMatrix {
* @throws XGBoostError native error
*/
public void setBaseMargin(float[] baseMargin) throws XGBoostError {
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixSetFloatInfo(handle, "base_margin", baseMargin));
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixSetFloatInfo(handle, "base_margin", baseMargin));
}
/**
@ -215,18 +200,18 @@ public class DMatrix {
* @throws XGBoostError native error
*/
public void setGroup(int[] group) throws XGBoostError {
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixSetGroup(handle, group));
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixSetGroup(handle, group));
}
private float[] getFloatInfo(String field) throws XGBoostError {
float[][] infos = new float[1][];
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixGetFloatInfo(handle, field, infos));
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixGetFloatInfo(handle, field, infos));
return infos[0];
}
private int[] getIntInfo(String field) throws XGBoostError {
int[][] infos = new int[1][];
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixGetUIntInfo(handle, field, infos));
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixGetUIntInfo(handle, field, infos));
return infos[0];
}
@ -269,7 +254,7 @@ public class DMatrix {
*/
public DMatrix slice(int[] rowIndex) throws XGBoostError {
long[] out = new long[1];
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixSliceDMatrix(handle, rowIndex, out));
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixSliceDMatrix(handle, rowIndex, out));
long sHandle = out[0];
DMatrix sMatrix = new DMatrix(sHandle);
return sMatrix;
@ -283,7 +268,7 @@ public class DMatrix {
*/
public long rowNum() throws XGBoostError {
long[] rowNum = new long[1];
JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixNumRow(handle, rowNum));
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixNumRow(handle, rowNum));
return rowNum[0];
}

View File

@ -1,52 +0,0 @@
/*
Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.java;
import java.io.IOException;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
/**
* Error handle for Xgboost.
*/
class JNIErrorHandle {
private static final Log logger = LogFactory.getLog(DMatrix.class);
//load native library
static {
try {
NativeLibLoader.initXGBoost();
} catch (IOException ex) {
logger.error("load native library failed.");
logger.error(ex);
}
}
/**
* Check the return value of C API.
*
* @param ret return valud of xgboostJNI C API call
* @throws XGBoostError native error
*/
static void checkCall(int ret) throws XGBoostError {
if (ret != 0) {
throw new XGBoostError(XGBoostJNI.XGBGetLastError());
}
}
}

View File

@ -21,7 +21,6 @@ import java.lang.reflect.Field;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
/**
* class to load native library
*
@ -35,7 +34,7 @@ class NativeLibLoader {
private static final String nativeResourcePath = "/lib/";
private static final String[] libNames = new String[]{"xgboost4j"};
public static synchronized void initXGBoost() throws IOException {
static synchronized void initXGBoost() throws IOException {
if (!initialized) {
for (String libName : libNames) {
smartLoad(libName);

View File

@ -1,30 +1,14 @@
package ml.dmlc.xgboost4j.java;
import java.io.IOException;
import java.io.Serializable;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Arrays;
import java.util.Map;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
/**
* Rabit global class for synchronization.
*/
public class Rabit {
private static final Log logger = LogFactory.getLog(DMatrix.class);
//load native library
static {
try {
NativeLibLoader.initXGBoost();
} catch (IOException ex) {
logger.error("load native library failed.");
logger.error(ex);
}
}
public enum OpType implements Serializable {
MAX(0), MIN(1), SUM(2), BITWISE_OR(3);

View File

@ -15,9 +15,11 @@
*/
package ml.dmlc.xgboost4j.java;
import java.nio.ByteBuffer;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
/**
* xgboost JNI functions
* change 2015-7-6: *use a long[] (length=1) as container of handle to get the output DMatrix or Booster
@ -25,6 +27,28 @@ import java.nio.ByteBuffer;
* @author hzx
*/
class XGBoostJNI {
private static final Log logger = LogFactory.getLog(DMatrix.class);
static {
try {
NativeLibLoader.initXGBoost();
} catch (Exception ex) {
logger.error("Failed to load native library", ex);
throw new RuntimeException(ex);
}
}
/**
* Check the return code of the JNI call.
*
* @throws XGBoostError if the call failed.
*/
static void checkCall(int ret) throws XGBoostError {
if (ret != 0) {
throw new XGBoostError(XGBGetLastError());
}
}
public final static native String XGBGetLastError();
public final static native int XGDMatrixCreateFromFile(String fname, int silent, long[] out);