From 5c9e50148ac889afd29cc001c1824a151b622723 Mon Sep 17 00:00:00 2001 From: tqchen Date: Wed, 2 Mar 2016 22:12:17 -0800 Subject: [PATCH] [JVM-PKG] Update JNI to include rabit codes --- .../main/java/ml/dmlc/xgboost4j/DMatrix.java | 2 +- .../ml/dmlc/xgboost4j/JNIErrorHandle.java | 1 + .../ml/dmlc/xgboost4j/JavaBoosterImpl.java | 21 + .../main/java/ml/dmlc/xgboost4j/Rabit.java | 93 +++ .../main/java/ml/dmlc/xgboost4j/XGBoost.java | 20 +- .../java/ml/dmlc/xgboost4j/XgboostJNI.java | 15 +- .../xgboost4j/src/native/xgboost4j.cpp | 657 ++++++++++-------- jvm-packages/xgboost4j/src/native/xgboost4j.h | 80 +++ 8 files changed, 609 insertions(+), 280 deletions(-) create mode 100644 jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/Rabit.java diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/DMatrix.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/DMatrix.java index 4b498caf1..1a50137cb 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/DMatrix.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/DMatrix.java @@ -21,7 +21,7 @@ import org.apache.commons.logging.LogFactory; import java.io.IOException; /** - * DMatrix for xgboost, similar to the python wrapper xgboost.py + * DMatrix for xgboost. * * @author hzx */ diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/JNIErrorHandle.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/JNIErrorHandle.java index 06474dbb4..741151d51 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/JNIErrorHandle.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/JNIErrorHandle.java @@ -48,4 +48,5 @@ class JNIErrorHandle { throw new XGBoostError(XgboostJNI.XGBGetLastError()); } } + } diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/JavaBoosterImpl.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/JavaBoosterImpl.java index 321b7fead..9c1659976 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/JavaBoosterImpl.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/JavaBoosterImpl.java @@ -441,6 +441,27 @@ class JavaBoosterImpl implements Booster { return featureScore; } + /** + * Load the booster model from thread-local rabit checkpoint. + * This is only used in distributed training. + * @return the stored version number of the checkpoint. + * @throws XGBoostError + */ + int loadRabitCheckpoint() throws XGBoostError { + int[] out = new int[1]; + JNIErrorHandle.checkCall(XgboostJNI.XGBoosterLoadRabitCheckpoint(this.handle, out)); + return out[0]; + } + + /** + * Save the booster model into thread-local rabit checkpoint. + * This is only used in distributed training. + * @throws XGBoostError + */ + void saveRabitCheckpoint() throws XGBoostError { + JNIErrorHandle.checkCall(XgboostJNI.XGBoosterSaveRabitCheckpoint(this.handle)); + } + /** * transfer DMatrix array to handle array (used for native functions) * diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/Rabit.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/Rabit.java new file mode 100644 index 000000000..3c8bc4142 --- /dev/null +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/Rabit.java @@ -0,0 +1,93 @@ +package ml.dmlc.xgboost4j; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import java.util.Map; +import java.io.IOException; + +/** + * Rabit global class for synchronization. + */ +public class Rabit { + private static final Log logger = LogFactory.getLog(DMatrix.class); + //load native library + static { + try { + NativeLibLoader.initXgBoost(); + } catch (IOException ex) { + logger.error("load native library failed."); + logger.error(ex); + } + } + + private static void checkCall(int ret) throws XGBoostError { + if (ret != 0) { + throw new XGBoostError(XgboostJNI.XGBGetLastError()); + } + } + + /** + * Initialize the rabit library on current working thread. + * @param envs The additional environment variables to pass to rabit. + * @throws XGBoostError + */ + public static void init(Map envs) throws XGBoostError { + String[] args = new String[envs.size()]; + int idx = 0; + for (java.util.Map.Entry e : envs.entrySet()) { + args[idx++] = e.getKey() + '=' + e.getValue(); + } + checkCall(XgboostJNI.RabitInit(args)); + } + + /** + * Shutdown the rabit engine in current working thread, equals to finalize. + * @throws XGBoostError + */ + public static void shutdown() throws XGBoostError { + checkCall(XgboostJNI.RabitFinalize()); + } + + /** + * Print the message on rabit tracker. + * @param msg + * @throws XGBoostError + */ + public static void trackerPrint(String msg) throws XGBoostError { + checkCall(XgboostJNI.RabitTrackerPrint(msg)); + } + + /** + * Get version number of current stored model in the thread. + * which means how many calls to CheckPoint we made so far. + * @return version Number. + * @throws XGBoostError + */ + public static int versionNumber() throws XGBoostError { + int[] out = new int[1]; + checkCall(XgboostJNI.RabitVersionNumber(out)); + return out[0]; + } + + /** + * get rank of current thread. + * @return the rank. + * @throws XGBoostError + */ + public static int getRank() throws XGBoostError { + int[] out = new int[1]; + checkCall(XgboostJNI.RabitGetRank(out)); + return out[0]; + } + + /** + * get world size of current job. + * @return the worldsize + * @throws XGBoostError + */ + public static int getWorldSize() throws XGBoostError { + int[] out = new int[1]; + checkCall(XgboostJNI.RabitGetWorldSize(out)); + return out[0]; + } +} diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/XGBoost.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/XGBoost.java index cea4ae5bf..4214d8f13 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/XGBoost.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/XGBoost.java @@ -72,14 +72,20 @@ public class XGBoost { } //initialize booster - Booster booster = new JavaBoosterImpl(params, allMats); + JavaBoosterImpl booster = new JavaBoosterImpl(params, allMats); + + int version = booster.loadRabitCheckpoint(); //begin to train - for (int iter = 0; iter < round; iter++) { - if (obj != null) { - booster.update(dtrain, obj); - } else { - booster.update(dtrain, iter); + for (int iter = version / 2; iter < round; iter++) { + if (version % 2 == 0) { + if (obj != null) { + booster.update(dtrain, obj); + } else { + booster.update(dtrain, iter); + } + booster.saveRabitCheckpoint(); + version += 1; } //evaluation @@ -92,6 +98,8 @@ public class XGBoost { } logger.info(evalInfo); } + booster.saveRabitCheckpoint(); + version += 1; } return booster; } diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/XgboostJNI.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/XgboostJNI.java index 10ba1802b..c26968b54 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/XgboostJNI.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/XgboostJNI.java @@ -16,7 +16,7 @@ package ml.dmlc.xgboost4j; /** - * xgboost jni wrapper functions for xgboost_wrapper.h + * xgboost JNI functions * change 2015-7-6: *use a long[] (length=1) as container of handle to get the output DMatrix or Booster * * @author hzx @@ -80,4 +80,17 @@ class XgboostJNI { public final static native int XGBoosterDumpModel(long handle, String fmap, int with_stats, String[][] out_strings); + + public final static native int XGBoosterGetAttr(long handle, String key, String[] out_string); + public final static native int XGBoosterSetAttr(long handle, String key, String value); + public final static native int XGBoosterLoadRabitCheckpoint(long handle, int[] out_version); + public final static native int XGBoosterSaveRabitCheckpoint(long handle); + + // rabit functions + public final static native int RabitInit(String[] args); + public final static native int RabitFinalize(); + public final static native int RabitTrackerPrint(String msg); + public final static native int RabitGetRank(int[] out); + public final static native int RabitGetWorldSize(int[] out); + public final static native int RabitVersionNumber(int[] out); } diff --git a/jvm-packages/xgboost4j/src/native/xgboost4j.cpp b/jvm-packages/xgboost4j/src/native/xgboost4j.cpp index 0d976a33f..8556be4a5 100644 --- a/jvm-packages/xgboost4j/src/native/xgboost4j.cpp +++ b/jvm-packages/xgboost4j/src/native/xgboost4j.cpp @@ -1,45 +1,50 @@ /* - Copyright (c) 2014 by Contributors - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - */ + Copyright (c) 2014 by Contributors + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at -#include "xgboost/c_api.h" -#include "xgboost4j.h" + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +#include +#include "./xgboost4j.h" #include +#include +#include //helper functions //set handle void setHandle(JNIEnv *jenv, jlongArray jhandle, void* handle) { - long out[1]; - out[0] = (long) handle; - jenv->SetLongArrayRegion(jhandle, 0, 1, (const jlong*) out); + long out = (long) handle; + jenv->SetLongArrayRegion(jhandle, 0, 1, &out); } JNIEXPORT jstring JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBGetLastError (JNIEnv *jenv, jclass jcls) { - jstring jresult = 0 ; - const char* result = XGBGetLastError(); - if (result) jresult = jenv->NewStringUTF(result); - return jresult; + jstring jresult = 0; + const char* result = XGBGetLastError(); + if (result != NULL) { + jresult = jenv->NewStringUTF(result); + } + return jresult; } JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromFile (JNIEnv *jenv, jclass jcls, jstring jfname, jint jsilent, jlongArray jout) { - DMatrixHandle result; - const char* fname = jenv->GetStringUTFChars(jfname, 0); - int ret = XGDMatrixCreateFromFile(fname, jsilent, &result); - if (fname) jenv->ReleaseStringUTFChars(jfname, fname); - setHandle(jenv, jout, result); - return ret; + DMatrixHandle result; + const char* fname = jenv->GetStringUTFChars(jfname, 0); + int ret = XGDMatrixCreateFromFile(fname, jsilent, &result); + if (fname) { + jenv->ReleaseStringUTFChars(jfname, fname); + } + setHandle(jenv, jout, result); + return ret; } /* @@ -49,19 +54,19 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromFile */ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromCSR (JNIEnv *jenv, jclass jcls, jlongArray jindptr, jintArray jindices, jfloatArray jdata, jlongArray jout) { - DMatrixHandle result; - jlong* indptr = jenv->GetLongArrayElements(jindptr, 0); - jint* indices = jenv->GetIntArrayElements(jindices, 0); - jfloat* data = jenv->GetFloatArrayElements(jdata, 0); - bst_ulong nindptr = (bst_ulong)jenv->GetArrayLength(jindptr); - bst_ulong nelem = (bst_ulong)jenv->GetArrayLength(jdata); - int ret = (jint) XGDMatrixCreateFromCSR((unsigned long const *)indptr, (unsigned int const *)indices, (float const *)data, nindptr, nelem, &result); - setHandle(jenv, jout, result); - //Release - jenv->ReleaseLongArrayElements(jindptr, indptr, 0); - jenv->ReleaseIntArrayElements(jindices, indices, 0); - jenv->ReleaseFloatArrayElements(jdata, data, 0); - return ret; + DMatrixHandle result; + jlong* indptr = jenv->GetLongArrayElements(jindptr, 0); + jint* indices = jenv->GetIntArrayElements(jindices, 0); + jfloat* data = jenv->GetFloatArrayElements(jdata, 0); + bst_ulong nindptr = (bst_ulong)jenv->GetArrayLength(jindptr); + bst_ulong nelem = (bst_ulong)jenv->GetArrayLength(jdata); + int ret = (jint) XGDMatrixCreateFromCSR((unsigned long const *)indptr, (unsigned int const *)indices, (float const *)data, nindptr, nelem, &result); + setHandle(jenv, jout, result); + //Release + jenv->ReleaseLongArrayElements(jindptr, indptr, 0); + jenv->ReleaseIntArrayElements(jindices, indices, 0); + jenv->ReleaseFloatArrayElements(jdata, data, 0); + return ret; } /* @@ -71,21 +76,21 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromCSR */ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromCSC (JNIEnv *jenv, jclass jcls, jlongArray jindptr, jintArray jindices, jfloatArray jdata, jlongArray jout) { - DMatrixHandle result; - jlong* indptr = jenv->GetLongArrayElements(jindptr, NULL); - jint* indices = jenv->GetIntArrayElements(jindices, 0); - jfloat* data = jenv->GetFloatArrayElements(jdata, NULL); - bst_ulong nindptr = (bst_ulong)jenv->GetArrayLength(jindptr); - bst_ulong nelem = (bst_ulong)jenv->GetArrayLength(jdata); + DMatrixHandle result; + jlong* indptr = jenv->GetLongArrayElements(jindptr, NULL); + jint* indices = jenv->GetIntArrayElements(jindices, 0); + jfloat* data = jenv->GetFloatArrayElements(jdata, NULL); + bst_ulong nindptr = (bst_ulong)jenv->GetArrayLength(jindptr); + bst_ulong nelem = (bst_ulong)jenv->GetArrayLength(jdata); - int ret = (jint) XGDMatrixCreateFromCSC((unsigned long const *)indptr, (unsigned int const *)indices, (float const *)data, nindptr, nelem, &result); - setHandle(jenv, jout, result); - //release - jenv->ReleaseLongArrayElements(jindptr, indptr, 0); - jenv->ReleaseIntArrayElements(jindices, indices, 0); - jenv->ReleaseFloatArrayElements(jdata, data, 0); - - return ret; + int ret = (jint) XGDMatrixCreateFromCSC((unsigned long const *)indptr, (unsigned int const *)indices, (float const *)data, nindptr, nelem, &result); + setHandle(jenv, jout, result); + //release + jenv->ReleaseLongArrayElements(jindptr, indptr, 0); + jenv->ReleaseIntArrayElements(jindices, indices, 0); + jenv->ReleaseFloatArrayElements(jdata, data, 0); + + return ret; } /* @@ -95,15 +100,15 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromCSC */ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromMat (JNIEnv *jenv, jclass jcls, jfloatArray jdata, jint jnrow, jint jncol, jfloat jmiss, jlongArray jout) { - DMatrixHandle result; - jfloat* data = jenv->GetFloatArrayElements(jdata, 0); - bst_ulong nrow = (bst_ulong)jnrow; - bst_ulong ncol = (bst_ulong)jncol; - int ret = (jint) XGDMatrixCreateFromMat((float const *)data, nrow, ncol, jmiss, &result); - setHandle(jenv, jout, result); - //release - jenv->ReleaseFloatArrayElements(jdata, data, 0); - return ret; + DMatrixHandle result; + jfloat* data = jenv->GetFloatArrayElements(jdata, 0); + bst_ulong nrow = (bst_ulong)jnrow; + bst_ulong ncol = (bst_ulong)jncol; + int ret = (jint) XGDMatrixCreateFromMat((float const *)data, nrow, ncol, jmiss, &result); + setHandle(jenv, jout, result); + //release + jenv->ReleaseFloatArrayElements(jdata, data, 0); + return ret; } /* @@ -113,18 +118,18 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromMat */ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSliceDMatrix (JNIEnv *jenv, jclass jcls, jlong jhandle, jintArray jindexset, jlongArray jout) { - DMatrixHandle result; - DMatrixHandle handle = (DMatrixHandle) jhandle; + DMatrixHandle result; + DMatrixHandle handle = (DMatrixHandle) jhandle; - jint* indexset = jenv->GetIntArrayElements(jindexset, 0); - bst_ulong len = (bst_ulong)jenv->GetArrayLength(jindexset); + jint* indexset = jenv->GetIntArrayElements(jindexset, 0); + bst_ulong len = (bst_ulong)jenv->GetArrayLength(jindexset); - int ret = XGDMatrixSliceDMatrix(handle, (int const *)indexset, len, &result); - setHandle(jenv, jout, result); - //release - jenv->ReleaseIntArrayElements(jindexset, indexset, 0); - - return ret; + int ret = XGDMatrixSliceDMatrix(handle, (int const *)indexset, len, &result); + setHandle(jenv, jout, result); + //release + jenv->ReleaseIntArrayElements(jindexset, indexset, 0); + + return ret; } /* @@ -133,10 +138,10 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSliceDMatrix * Signature: (J)V */ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixFree - (JNIEnv *jenv, jclass jcls, jlong jhandle) { - DMatrixHandle handle = (DMatrixHandle) jhandle; - int ret = XGDMatrixFree(handle); - return ret; + (JNIEnv *jenv, jclass jcls, jlong jhandle) { + DMatrixHandle handle = (DMatrixHandle) jhandle; + int ret = XGDMatrixFree(handle); + return ret; } /* @@ -146,11 +151,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixFree */ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSaveBinary (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfname, jint jsilent) { - DMatrixHandle handle = (DMatrixHandle) jhandle; - const char* fname = jenv->GetStringUTFChars(jfname, 0); - int ret = XGDMatrixSaveBinary(handle, fname, jsilent); - if (fname) jenv->ReleaseStringUTFChars(jfname, (const char *)fname); - return ret; + DMatrixHandle handle = (DMatrixHandle) jhandle; + const char* fname = jenv->GetStringUTFChars(jfname, 0); + int ret = XGDMatrixSaveBinary(handle, fname, jsilent); + if (fname) jenv->ReleaseStringUTFChars(jfname, (const char *)fname); + return ret; } /* @@ -160,16 +165,16 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSaveBinary */ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSetFloatInfo (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jfloatArray jarray) { - DMatrixHandle handle = (DMatrixHandle) jhandle; - const char* field = jenv->GetStringUTFChars(jfield, 0); - - jfloat* array = jenv->GetFloatArrayElements(jarray, NULL); - bst_ulong len = (bst_ulong)jenv->GetArrayLength(jarray); - int ret = XGDMatrixSetFloatInfo(handle, field, (float const *)array, len); - //release - if (field) jenv->ReleaseStringUTFChars(jfield, field); - jenv->ReleaseFloatArrayElements(jarray, array, 0); - return ret; + DMatrixHandle handle = (DMatrixHandle) jhandle; + const char* field = jenv->GetStringUTFChars(jfield, 0); + + jfloat* array = jenv->GetFloatArrayElements(jarray, NULL); + bst_ulong len = (bst_ulong)jenv->GetArrayLength(jarray); + int ret = XGDMatrixSetFloatInfo(handle, field, (float const *)array, len); + //release + if (field) jenv->ReleaseStringUTFChars(jfield, field); + jenv->ReleaseFloatArrayElements(jarray, array, 0); + return ret; } /* @@ -179,16 +184,16 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSetFloatInfo */ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSetUIntInfo (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jintArray jarray) { - DMatrixHandle handle = (DMatrixHandle) jhandle; - const char* field = jenv->GetStringUTFChars(jfield, 0); - jint* array = jenv->GetIntArrayElements(jarray, NULL); - bst_ulong len = (bst_ulong)jenv->GetArrayLength(jarray); - int ret = XGDMatrixSetUIntInfo(handle, (char const *)field, (unsigned int const *)array, len); - //release - if (field) jenv->ReleaseStringUTFChars(jfield, (const char *)field); - jenv->ReleaseIntArrayElements(jarray, array, 0); - - return ret; + DMatrixHandle handle = (DMatrixHandle) jhandle; + const char* field = jenv->GetStringUTFChars(jfield, 0); + jint* array = jenv->GetIntArrayElements(jarray, NULL); + bst_ulong len = (bst_ulong)jenv->GetArrayLength(jarray); + int ret = XGDMatrixSetUIntInfo(handle, (char const *)field, (unsigned int const *)array, len); + //release + if (field) jenv->ReleaseStringUTFChars(jfield, (const char *)field); + jenv->ReleaseIntArrayElements(jarray, array, 0); + + return ret; } /* @@ -198,13 +203,13 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSetUIntInfo */ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSetGroup (JNIEnv * jenv, jclass jcls, jlong jhandle, jintArray jarray) { - DMatrixHandle handle = (DMatrixHandle) jhandle; - jint* array = jenv->GetIntArrayElements(jarray, NULL); - bst_ulong len = (bst_ulong)jenv->GetArrayLength(jarray); - int ret = XGDMatrixSetGroup(handle, (unsigned int const *)array, len); - //release - jenv->ReleaseIntArrayElements(jarray, array, 0); - return ret; + DMatrixHandle handle = (DMatrixHandle) jhandle; + jint* array = jenv->GetIntArrayElements(jarray, NULL); + bst_ulong len = (bst_ulong)jenv->GetArrayLength(jarray); + int ret = XGDMatrixSetGroup(handle, (unsigned int const *)array, len); + //release + jenv->ReleaseIntArrayElements(jarray, array, 0); + return ret; } /* @@ -214,19 +219,19 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSetGroup */ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixGetFloatInfo (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jobjectArray jout) { - DMatrixHandle handle = (DMatrixHandle) jhandle; - const char* field = jenv->GetStringUTFChars(jfield, 0); - bst_ulong len; - float *result; - int ret = XGDMatrixGetFloatInfo(handle, field, &len, (const float**) &result); - if (field) jenv->ReleaseStringUTFChars(jfield, field); - - jsize jlen = (jsize) len; - jfloatArray jarray = jenv->NewFloatArray(jlen); - jenv->SetFloatArrayRegion(jarray, 0, jlen, (jfloat *) result); - jenv->SetObjectArrayElement(jout, 0, (jobject) jarray); - - return ret; + DMatrixHandle handle = (DMatrixHandle) jhandle; + const char* field = jenv->GetStringUTFChars(jfield, 0); + bst_ulong len; + float *result; + int ret = XGDMatrixGetFloatInfo(handle, field, &len, (const float**) &result); + if (field) jenv->ReleaseStringUTFChars(jfield, field); + + jsize jlen = (jsize) len; + jfloatArray jarray = jenv->NewFloatArray(jlen); + jenv->SetFloatArrayRegion(jarray, 0, jlen, (jfloat *) result); + jenv->SetObjectArrayElement(jout, 0, (jobject) jarray); + + return ret; } /* @@ -236,18 +241,18 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixGetFloatInfo */ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixGetUIntInfo (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jobjectArray jout) { - DMatrixHandle handle = (DMatrixHandle) jhandle; - const char* field = jenv->GetStringUTFChars(jfield, 0); - bst_ulong len; - unsigned int *result; - int ret = (jint) XGDMatrixGetUIntInfo(handle, field, &len, (const unsigned int **) &result); - if (field) jenv->ReleaseStringUTFChars(jfield, field); - - jsize jlen = (jsize) len; - jintArray jarray = jenv->NewIntArray(jlen); - jenv->SetIntArrayRegion(jarray, 0, jlen, (jint *) result); - jenv->SetObjectArrayElement(jout, 0, jarray); - return ret; + DMatrixHandle handle = (DMatrixHandle) jhandle; + const char* field = jenv->GetStringUTFChars(jfield, 0); + bst_ulong len; + unsigned int *result; + int ret = (jint) XGDMatrixGetUIntInfo(handle, field, &len, (const unsigned int **) &result); + if (field) jenv->ReleaseStringUTFChars(jfield, field); + + jsize jlen = (jsize) len; + jintArray jarray = jenv->NewIntArray(jlen); + jenv->SetIntArrayRegion(jarray, 0, jlen, (jint *) result); + jenv->SetObjectArrayElement(jout, 0, jarray); + return ret; } /* @@ -257,11 +262,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixGetUIntInfo */ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixNumRow (JNIEnv *jenv, jclass jcls, jlong jhandle, jlongArray jout) { - DMatrixHandle handle = (DMatrixHandle) jhandle; - bst_ulong result[1]; - int ret = (jint) XGDMatrixNumRow(handle, result); - jenv->SetLongArrayRegion(jout, 0, 1, (const jlong *) result); - return ret; + DMatrixHandle handle = (DMatrixHandle) jhandle; + bst_ulong result[1]; + int ret = (jint) XGDMatrixNumRow(handle, result); + jenv->SetLongArrayRegion(jout, 0, 1, (const jlong *) result); + return ret; } /* @@ -271,30 +276,29 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixNumRow */ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterCreate (JNIEnv *jenv, jclass jcls, jlongArray jhandles, jlongArray jout) { - DMatrixHandle* handles; - bst_ulong len = 0; - jlong* cjhandles = 0; - BoosterHandle result; - - if(jhandles) { - len = (bst_ulong)jenv->GetArrayLength(jhandles); - handles = new DMatrixHandle[len]; - //put handle from jhandles to chandles - cjhandles = jenv->GetLongArrayElements(jhandles, 0); - for(bst_ulong i=0; iGetArrayLength(jhandles); + handles = new DMatrixHandle[len]; + //put handle from jhandles to chandles + cjhandles = jenv->GetLongArrayElements(jhandles, 0); + for(bst_ulong i=0; iReleaseLongArrayElements(jhandles, cjhandles, 0); - } - setHandle(jenv, jout, result); - - return ret; + } + + int ret = XGBoosterCreate(handles, len, &result); + //release + if (jhandles) { + delete[] handles; + jenv->ReleaseLongArrayElements(jhandles, cjhandles, 0); + } + setHandle(jenv, jout, result); + return ret; } /* @@ -316,14 +320,14 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterFree */ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterSetParam (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jname, jstring jvalue) { - BoosterHandle handle = (BoosterHandle) jhandle; - const char* name = jenv->GetStringUTFChars(jname, 0); - const char* value = jenv->GetStringUTFChars(jvalue, 0); - int ret = XGBoosterSetParam(handle, name, value); - //release - if (name) jenv->ReleaseStringUTFChars(jname, name); - if (value) jenv->ReleaseStringUTFChars(jvalue, value); - return ret; + BoosterHandle handle = (BoosterHandle) jhandle; + const char* name = jenv->GetStringUTFChars(jname, 0); + const char* value = jenv->GetStringUTFChars(jvalue, 0); + int ret = XGBoosterSetParam(handle, name, value); + //release + if (name) jenv->ReleaseStringUTFChars(jname, name); + if (value) jenv->ReleaseStringUTFChars(jvalue, value); + return ret; } /* @@ -333,9 +337,9 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterSetParam */ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterUpdateOneIter (JNIEnv *jenv, jclass jcls, jlong jhandle, jint jiter, jlong jdtrain) { - BoosterHandle handle = (BoosterHandle) jhandle; - DMatrixHandle dtrain = (DMatrixHandle) jdtrain; - return XGBoosterUpdateOneIter(handle, jiter, dtrain); + BoosterHandle handle = (BoosterHandle) jhandle; + DMatrixHandle dtrain = (DMatrixHandle) jdtrain; + return XGBoosterUpdateOneIter(handle, jiter, dtrain); } /* @@ -345,16 +349,16 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterUpdateOneIter */ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterBoostOneIter (JNIEnv *jenv, jclass jcls, jlong jhandle, jlong jdtrain, jfloatArray jgrad, jfloatArray jhess) { - BoosterHandle handle = (BoosterHandle) jhandle; - DMatrixHandle dtrain = (DMatrixHandle) jdtrain; - jfloat* grad = jenv->GetFloatArrayElements(jgrad, 0); - jfloat* hess = jenv->GetFloatArrayElements(jhess, 0); - bst_ulong len = (bst_ulong)jenv->GetArrayLength(jgrad); - int ret = XGBoosterBoostOneIter(handle, dtrain, grad, hess, len); - //release - jenv->ReleaseFloatArrayElements(jgrad, grad, 0); - jenv->ReleaseFloatArrayElements(jhess, hess, 0); - return ret; + BoosterHandle handle = (BoosterHandle) jhandle; + DMatrixHandle dtrain = (DMatrixHandle) jdtrain; + jfloat* grad = jenv->GetFloatArrayElements(jgrad, 0); + jfloat* hess = jenv->GetFloatArrayElements(jhess, 0); + bst_ulong len = (bst_ulong)jenv->GetArrayLength(jgrad); + int ret = XGBoosterBoostOneIter(handle, dtrain, grad, hess, len); + //release + jenv->ReleaseFloatArrayElements(jgrad, grad, 0); + jenv->ReleaseFloatArrayElements(jhess, hess, 0); + return ret; } /* @@ -364,45 +368,45 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterBoostOneIter */ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterEvalOneIter (JNIEnv *jenv, jclass jcls, jlong jhandle, jint jiter, jlongArray jdmats, jobjectArray jevnames, jobjectArray jout) { - BoosterHandle handle = (BoosterHandle) jhandle; - DMatrixHandle* dmats = 0; - char **evnames = 0; - char *result = 0; - bst_ulong len = (bst_ulong)jenv->GetArrayLength(jdmats); - if(len > 0) { - dmats = new DMatrixHandle[len]; - evnames = new char*[len]; - } - //put handle from jhandles to chandles - jlong* cjdmats = jenv->GetLongArrayElements(jdmats, 0); + BoosterHandle handle = (BoosterHandle) jhandle; + DMatrixHandle* dmats = 0; + char **evnames = 0; + char *result = 0; + bst_ulong len = (bst_ulong)jenv->GetArrayLength(jdmats); + if(len > 0) { + dmats = new DMatrixHandle[len]; + evnames = new char*[len]; + } + //put handle from jhandles to chandles + jlong* cjdmats = jenv->GetLongArrayElements(jdmats, 0); + for(bst_ulong i=0; iGetObjectArrayElement(jevnames, i); + const char* cevname = jenv->GetStringUTFChars(jevname, 0); + evnames[i] = new char[jenv->GetStringLength(jevname)]; + strcpy(evnames[i], cevname); + jenv->ReleaseStringUTFChars(jevname, cevname); + } + + int ret = XGBoosterEvalOneIter(handle, jiter, dmats, (char const *(*)) evnames, len, (const char **) &result); + if(len > 0) { + delete[] dmats; + //release string chars for(bst_ulong i=0; iGetObjectArrayElement(jevnames, i); - const char* cevname = jenv->GetStringUTFChars(jevname, 0); - evnames[i] = new char[jenv->GetStringLength(jevname)]; - strcpy(evnames[i], cevname); - jenv->ReleaseStringUTFChars(jevname, cevname); - } - - int ret = XGBoosterEvalOneIter(handle, jiter, dmats, (char const *(*)) evnames, len, (const char **) &result); - if(len > 0) { - delete[] dmats; - //release string chars - for(bst_ulong i=0; iReleaseLongArrayElements(jdmats, cjdmats, 0); - } - - jstring jinfo = 0; - if (result) jinfo = jenv->NewStringUTF((const char *) result); - jenv->SetObjectArrayElement(jout, 0, jinfo); - - return ret; + delete[] evnames; + jenv->ReleaseLongArrayElements(jdmats, cjdmats, 0); + } + + jstring jinfo = 0; + if (result) jinfo = jenv->NewStringUTF((const char *) result); + jenv->SetObjectArrayElement(jout, 0, jinfo); + + return ret; } /* @@ -412,17 +416,17 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterEvalOneIter */ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterPredict (JNIEnv *jenv, jclass jcls, jlong jhandle, jlong jdmat, jint joption_mask, jint jntree_limit, jobjectArray jout) { - BoosterHandle handle = (BoosterHandle) jhandle; - DMatrixHandle dmat = (DMatrixHandle) jdmat; - bst_ulong len; - float *result; - int ret = XGBoosterPredict(handle, dmat, joption_mask, (unsigned int) jntree_limit, &len, (const float **) &result); - - jsize jlen = (jsize) len; - jfloatArray jarray = jenv->NewFloatArray(jlen); - jenv->SetFloatArrayRegion(jarray, 0, jlen, (jfloat *) result); - jenv->SetObjectArrayElement(jout, 0, jarray); - return ret; + BoosterHandle handle = (BoosterHandle) jhandle; + DMatrixHandle dmat = (DMatrixHandle) jdmat; + bst_ulong len; + float *result; + int ret = XGBoosterPredict(handle, dmat, joption_mask, (unsigned int) jntree_limit, &len, (const float **) &result); + + jsize jlen = (jsize) len; + jfloatArray jarray = jenv->NewFloatArray(jlen); + jenv->SetFloatArrayRegion(jarray, 0, jlen, (jfloat *) result); + jenv->SetObjectArrayElement(jout, 0, jarray); + return ret; } /* @@ -432,12 +436,12 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterPredict */ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterLoadModel (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfname) { - BoosterHandle handle = (BoosterHandle) jhandle; - const char* fname = jenv->GetStringUTFChars(jfname, 0); - - int ret = XGBoosterLoadModel(handle, fname); - if (fname) jenv->ReleaseStringUTFChars(jfname,fname); - return ret; + BoosterHandle handle = (BoosterHandle) jhandle; + const char* fname = jenv->GetStringUTFChars(jfname, 0); + + int ret = XGBoosterLoadModel(handle, fname); + if (fname) jenv->ReleaseStringUTFChars(jfname,fname); + return ret; } /* @@ -447,13 +451,13 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterLoadModel */ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterSaveModel (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfname) { - BoosterHandle handle = (BoosterHandle) jhandle; - const char* fname = jenv->GetStringUTFChars(jfname, 0); - - int ret = XGBoosterSaveModel(handle, fname); - if (fname) jenv->ReleaseStringUTFChars(jfname, fname); - - return ret; + BoosterHandle handle = (BoosterHandle) jhandle; + const char* fname = jenv->GetStringUTFChars(jfname, 0); + + int ret = XGBoosterSaveModel(handle, fname); + if (fname) jenv->ReleaseStringUTFChars(jfname, fname); + + return ret; } /* @@ -463,9 +467,9 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterSaveModel */ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterLoadModelFromBuffer (JNIEnv *jenv, jclass jcls, jlong jhandle, jlong jbuf, jlong jlen) { - BoosterHandle handle = (BoosterHandle) jhandle; - void *buf = (void*) jbuf; - return XGBoosterLoadModelFromBuffer(handle, (void const *)buf, (bst_ulong) jlen); + BoosterHandle handle = (BoosterHandle) jhandle; + void *buf = (void*) jbuf; + return XGBoosterLoadModelFromBuffer(handle, (void const *)buf, (bst_ulong) jlen); } /* @@ -475,16 +479,16 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterLoadModelFromB */ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterGetModelRaw (JNIEnv * jenv, jclass jcls, jlong jhandle, jobjectArray jout) { - BoosterHandle handle = (BoosterHandle) jhandle; - bst_ulong len = 0; - char *result; + BoosterHandle handle = (BoosterHandle) jhandle; + bst_ulong len = 0; + char *result; - int ret = XGBoosterGetModelRaw(handle, &len, (const char **) &result); - if (result){ - jstring jinfo = jenv->NewStringUTF((const char *) result); - jenv->SetObjectArrayElement(jout, 0, jinfo); - } - return ret; + int ret = XGBoosterGetModelRaw(handle, &len, (const char **) &result); + if (result) { + jstring jinfo = jenv->NewStringUTF((const char *) result); + jenv->SetObjectArrayElement(jout, 0, jinfo); + } + return ret; } /* @@ -494,20 +498,129 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterGetModelRaw */ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterDumpModel (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfmap, jint jwith_stats, jobjectArray jout) { - BoosterHandle handle = (BoosterHandle) jhandle; - const char *fmap = jenv->GetStringUTFChars(jfmap, 0); - bst_ulong len = 0; - char **result; - - int ret = XGBoosterDumpModel(handle, fmap, jwith_stats, &len, (const char ***) &result); - - jsize jlen = (jsize) len; - jobjectArray jinfos = jenv->NewObjectArray(jlen, jenv->FindClass("java/lang/String"), jenv->NewStringUTF("")); - for(int i=0 ; iSetObjectArrayElement(jinfos, i, jenv->NewStringUTF((const char*) result[i])); - } - jenv->SetObjectArrayElement(jout, 0, jinfos); - - if (fmap) jenv->ReleaseStringUTFChars(jfmap, (const char *)fmap); - return ret; + BoosterHandle handle = (BoosterHandle) jhandle; + const char *fmap = jenv->GetStringUTFChars(jfmap, 0); + bst_ulong len = 0; + char **result; + + int ret = XGBoosterDumpModel(handle, fmap, jwith_stats, &len, (const char ***) &result); + + jsize jlen = (jsize) len; + jobjectArray jinfos = jenv->NewObjectArray(jlen, jenv->FindClass("java/lang/String"), jenv->NewStringUTF("")); + for(int i=0 ; iSetObjectArrayElement(jinfos, i, jenv->NewStringUTF((const char*) result[i])); + } + jenv->SetObjectArrayElement(jout, 0, jinfos); + + if (fmap) jenv->ReleaseStringUTFChars(jfmap, (const char *)fmap); + return ret; +} + +/* + * Class: ml_dmlc_xgboost4j_XgboostJNI + * Method: XGBoosterLoadRabitCheckpoint + * Signature: (J[I)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterLoadRabitCheckpoint + (JNIEnv *jenv , jclass jcls, jlong jhandle, jintArray jout) { + BoosterHandle handle = (BoosterHandle) jhandle; + int version; + int ret = XGBoosterLoadRabitCheckpoint(handle, &version); + jenv->SetIntArrayRegion(jout, 0, 1, &version); + return ret; +} + +/* + * Class: ml_dmlc_xgboost4j_XgboostJNI + * Method: XGBoosterSaveRabitCheckpoint + * Signature: (J)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterSaveRabitCheckpoint + (JNIEnv *jenv, jclass jcls, jlong jhandle) { + BoosterHandle handle = (BoosterHandle) jhandle; + return XGBoosterSaveRabitCheckpoint(handle); +} + +/* + * Class: ml_dmlc_xgboost4j_XgboostJNI + * Method: RabitInit + * Signature: ([Ljava/lang/String;)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitInit + (JNIEnv *jenv, jclass jcls, jobjectArray jargs) { + std::vector args; + std::vector argv; + bst_ulong len = (bst_ulong)jenv->GetArrayLength(jargs); + for (bst_ulong i = 0; i < len; ++i) { + jstring arg = (jstring)jenv->GetObjectArrayElement(jargs, i); + std::string s(jenv->GetStringUTFChars(arg, 0), + jenv->GetStringLength(arg)); + if (s.length() != 0) args.push_back(s); + } + + for (size_t i = 0; i < args.size(); ++i) { + argv.push_back(&args[i][0]); + } + RabitInit(args.size(), args.size() == 0 ? NULL : &argv[0]); + return 0; +} + +/* + * Class: ml_dmlc_xgboost4j_XgboostJNI + * Method: RabitFinalize + * Signature: ()I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitFinalize + (JNIEnv *jenv, jclass jcls) { + RabitFinalize(); + return 0; +} + +/* + * Class: ml_dmlc_xgboost4j_XgboostJNI + * Method: RabitTrackerPrint + * Signature: (Ljava/lang/String;)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitTrackerPrint + (JNIEnv *jenv, jclass jcls, jstring jmsg) { + std::string str(jenv->GetStringUTFChars(jmsg, 0), + jenv->GetStringLength(jmsg)); + RabitTrackerPrint(str.c_str()); + return 0; +} + +/* + * Class: ml_dmlc_xgboost4j_XgboostJNI + * Method: RabitGetRank + * Signature: ([I)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitGetRank + (JNIEnv *jenv, jclass jcls, jintArray jout) { + int rank = RabitGetRank(); + jenv->SetIntArrayRegion(jout, 0, 1, &rank); + return 0; +} + +/* + * Class: ml_dmlc_xgboost4j_XgboostJNI + * Method: RabitGetWorldSize + * Signature: ([I)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitGetWorldSize + (JNIEnv *jenv, jclass jcls, jintArray jout) { + int out = RabitGetWorldSize(); + jenv->SetIntArrayRegion(jout, 0, 1, &out); + return 0; +} + +/* + * Class: ml_dmlc_xgboost4j_XgboostJNI + * Method: RabitVersionNumber + * Signature: ([I)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitVersionNumber + (JNIEnv *jenv, jclass jcls, jintArray jout) { + int out = RabitVersionNumber(); + jenv->SetIntArrayRegion(jout, 0, 1, &out); + return 0; } diff --git a/jvm-packages/xgboost4j/src/native/xgboost4j.h b/jvm-packages/xgboost4j/src/native/xgboost4j.h index d93da0ee6..023827c44 100644 --- a/jvm-packages/xgboost4j/src/native/xgboost4j.h +++ b/jvm-packages/xgboost4j/src/native/xgboost4j.h @@ -215,6 +215,86 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterGetModelRaw JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterDumpModel (JNIEnv *, jclass, jlong, jstring, jint, jobjectArray); +/* + * Class: ml_dmlc_xgboost4j_XgboostJNI + * Method: XGBoosterGetAttr + * Signature: (JLjava/lang/String;[Ljava/lang/String;)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterGetAttr + (JNIEnv *, jclass, jlong, jstring, jobjectArray); + +/* + * Class: ml_dmlc_xgboost4j_XgboostJNI + * Method: XGBoosterSetAttr + * Signature: (JLjava/lang/String;Ljava/lang/String;)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterSetAttr + (JNIEnv *, jclass, jlong, jstring, jstring); + +/* + * Class: ml_dmlc_xgboost4j_XgboostJNI + * Method: XGBoosterLoadRabitCheckpoint + * Signature: (J[I)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterLoadRabitCheckpoint + (JNIEnv *, jclass, jlong, jintArray); + +/* + * Class: ml_dmlc_xgboost4j_XgboostJNI + * Method: XGBoosterSaveRabitCheckpoint + * Signature: (J)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterSaveRabitCheckpoint + (JNIEnv *, jclass, jlong); + +/* + * Class: ml_dmlc_xgboost4j_XgboostJNI + * Method: RabitInit + * Signature: ([Ljava/lang/String;)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitInit + (JNIEnv *, jclass, jobjectArray); + +/* + * Class: ml_dmlc_xgboost4j_XgboostJNI + * Method: RabitFinalize + * Signature: ()I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitFinalize + (JNIEnv *, jclass); + +/* + * Class: ml_dmlc_xgboost4j_XgboostJNI + * Method: RabitTrackerPrint + * Signature: (Ljava/lang/String;)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitTrackerPrint + (JNIEnv *, jclass, jstring); + +/* + * Class: ml_dmlc_xgboost4j_XgboostJNI + * Method: RabitGetRank + * Signature: ([I)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitGetRank + (JNIEnv *, jclass, jintArray); + +/* + * Class: ml_dmlc_xgboost4j_XgboostJNI + * Method: RabitGetWorldSize + * Signature: ([I)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitGetWorldSize + (JNIEnv *, jclass, jintArray); + +/* + * Class: ml_dmlc_xgboost4j_XgboostJNI + * Method: RabitVersionNumber + * Signature: ([I)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitVersionNumber + (JNIEnv *, jclass, jintArray); + #ifdef __cplusplus } #endif