[JVM-PKG] Update JNI to include rabit codes

This commit is contained in:
tqchen 2016-03-02 22:12:17 -08:00
parent ced6d45e01
commit 5c9e50148a
8 changed files with 609 additions and 280 deletions

View File

@ -21,7 +21,7 @@ import org.apache.commons.logging.LogFactory;
import java.io.IOException; import java.io.IOException;
/** /**
* DMatrix for xgboost, similar to the python wrapper xgboost.py * DMatrix for xgboost.
* *
* @author hzx * @author hzx
*/ */

View File

@ -48,4 +48,5 @@ class JNIErrorHandle {
throw new XGBoostError(XgboostJNI.XGBGetLastError()); throw new XGBoostError(XgboostJNI.XGBGetLastError());
} }
} }
} }

View File

@ -441,6 +441,27 @@ class JavaBoosterImpl implements Booster {
return featureScore; return featureScore;
} }
/**
* Load the booster model from thread-local rabit checkpoint.
* This is only used in distributed training.
* @return the stored version number of the checkpoint.
* @throws XGBoostError
*/
int loadRabitCheckpoint() throws XGBoostError {
int[] out = new int[1];
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterLoadRabitCheckpoint(this.handle, out));
return out[0];
}
/**
* Save the booster model into thread-local rabit checkpoint.
* This is only used in distributed training.
* @throws XGBoostError
*/
void saveRabitCheckpoint() throws XGBoostError {
JNIErrorHandle.checkCall(XgboostJNI.XGBoosterSaveRabitCheckpoint(this.handle));
}
/** /**
* transfer DMatrix array to handle array (used for native functions) * transfer DMatrix array to handle array (used for native functions)
* *

View File

@ -0,0 +1,93 @@
package ml.dmlc.xgboost4j;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import java.util.Map;
import java.io.IOException;
/**
* Rabit global class for synchronization.
*/
public class Rabit {
private static final Log logger = LogFactory.getLog(DMatrix.class);
//load native library
static {
try {
NativeLibLoader.initXgBoost();
} catch (IOException ex) {
logger.error("load native library failed.");
logger.error(ex);
}
}
private static void checkCall(int ret) throws XGBoostError {
if (ret != 0) {
throw new XGBoostError(XgboostJNI.XGBGetLastError());
}
}
/**
* Initialize the rabit library on current working thread.
* @param envs The additional environment variables to pass to rabit.
* @throws XGBoostError
*/
public static void init(Map<String, String> envs) throws XGBoostError {
String[] args = new String[envs.size()];
int idx = 0;
for (java.util.Map.Entry<String, String> e : envs.entrySet()) {
args[idx++] = e.getKey() + '=' + e.getValue();
}
checkCall(XgboostJNI.RabitInit(args));
}
/**
* Shutdown the rabit engine in current working thread, equals to finalize.
* @throws XGBoostError
*/
public static void shutdown() throws XGBoostError {
checkCall(XgboostJNI.RabitFinalize());
}
/**
* Print the message on rabit tracker.
* @param msg
* @throws XGBoostError
*/
public static void trackerPrint(String msg) throws XGBoostError {
checkCall(XgboostJNI.RabitTrackerPrint(msg));
}
/**
* Get version number of current stored model in the thread.
* which means how many calls to CheckPoint we made so far.
* @return version Number.
* @throws XGBoostError
*/
public static int versionNumber() throws XGBoostError {
int[] out = new int[1];
checkCall(XgboostJNI.RabitVersionNumber(out));
return out[0];
}
/**
* get rank of current thread.
* @return the rank.
* @throws XGBoostError
*/
public static int getRank() throws XGBoostError {
int[] out = new int[1];
checkCall(XgboostJNI.RabitGetRank(out));
return out[0];
}
/**
* get world size of current job.
* @return the worldsize
* @throws XGBoostError
*/
public static int getWorldSize() throws XGBoostError {
int[] out = new int[1];
checkCall(XgboostJNI.RabitGetWorldSize(out));
return out[0];
}
}

View File

@ -72,14 +72,20 @@ public class XGBoost {
} }
//initialize booster //initialize booster
Booster booster = new JavaBoosterImpl(params, allMats); JavaBoosterImpl booster = new JavaBoosterImpl(params, allMats);
int version = booster.loadRabitCheckpoint();
//begin to train //begin to train
for (int iter = 0; iter < round; iter++) { for (int iter = version / 2; iter < round; iter++) {
if (obj != null) { if (version % 2 == 0) {
booster.update(dtrain, obj); if (obj != null) {
} else { booster.update(dtrain, obj);
booster.update(dtrain, iter); } else {
booster.update(dtrain, iter);
}
booster.saveRabitCheckpoint();
version += 1;
} }
//evaluation //evaluation
@ -92,6 +98,8 @@ public class XGBoost {
} }
logger.info(evalInfo); logger.info(evalInfo);
} }
booster.saveRabitCheckpoint();
version += 1;
} }
return booster; return booster;
} }

View File

@ -16,7 +16,7 @@
package ml.dmlc.xgboost4j; package ml.dmlc.xgboost4j;
/** /**
* xgboost jni wrapper functions for xgboost_wrapper.h * xgboost JNI functions
* change 2015-7-6: *use a long[] (length=1) as container of handle to get the output DMatrix or Booster * change 2015-7-6: *use a long[] (length=1) as container of handle to get the output DMatrix or Booster
* *
* @author hzx * @author hzx
@ -80,4 +80,17 @@ class XgboostJNI {
public final static native int XGBoosterDumpModel(long handle, String fmap, int with_stats, public final static native int XGBoosterDumpModel(long handle, String fmap, int with_stats,
String[][] out_strings); String[][] out_strings);
public final static native int XGBoosterGetAttr(long handle, String key, String[] out_string);
public final static native int XGBoosterSetAttr(long handle, String key, String value);
public final static native int XGBoosterLoadRabitCheckpoint(long handle, int[] out_version);
public final static native int XGBoosterSaveRabitCheckpoint(long handle);
// rabit functions
public final static native int RabitInit(String[] args);
public final static native int RabitFinalize();
public final static native int RabitTrackerPrint(String msg);
public final static native int RabitGetRank(int[] out);
public final static native int RabitGetWorldSize(int[] out);
public final static native int RabitVersionNumber(int[] out);
} }

View File

@ -1,45 +1,50 @@
/* /*
Copyright (c) 2014 by Contributors Copyright (c) 2014 by Contributors
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "xgboost/c_api.h" http://www.apache.org/licenses/LICENSE-2.0
#include "xgboost4j.h" Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <xgboost/c_api.h>
#include "./xgboost4j.h"
#include <cstring> #include <cstring>
#include <vector>
#include <string>
//helper functions //helper functions
//set handle //set handle
void setHandle(JNIEnv *jenv, jlongArray jhandle, void* handle) { void setHandle(JNIEnv *jenv, jlongArray jhandle, void* handle) {
long out[1]; long out = (long) handle;
out[0] = (long) handle; jenv->SetLongArrayRegion(jhandle, 0, 1, &out);
jenv->SetLongArrayRegion(jhandle, 0, 1, (const jlong*) out);
} }
JNIEXPORT jstring JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBGetLastError JNIEXPORT jstring JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBGetLastError
(JNIEnv *jenv, jclass jcls) { (JNIEnv *jenv, jclass jcls) {
jstring jresult = 0 ; jstring jresult = 0;
const char* result = XGBGetLastError(); const char* result = XGBGetLastError();
if (result) jresult = jenv->NewStringUTF(result); if (result != NULL) {
return jresult; jresult = jenv->NewStringUTF(result);
}
return jresult;
} }
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromFile JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromFile
(JNIEnv *jenv, jclass jcls, jstring jfname, jint jsilent, jlongArray jout) { (JNIEnv *jenv, jclass jcls, jstring jfname, jint jsilent, jlongArray jout) {
DMatrixHandle result; DMatrixHandle result;
const char* fname = jenv->GetStringUTFChars(jfname, 0); const char* fname = jenv->GetStringUTFChars(jfname, 0);
int ret = XGDMatrixCreateFromFile(fname, jsilent, &result); int ret = XGDMatrixCreateFromFile(fname, jsilent, &result);
if (fname) jenv->ReleaseStringUTFChars(jfname, fname); if (fname) {
setHandle(jenv, jout, result); jenv->ReleaseStringUTFChars(jfname, fname);
return ret; }
setHandle(jenv, jout, result);
return ret;
} }
/* /*
@ -49,19 +54,19 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromFile
*/ */
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromCSR JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromCSR
(JNIEnv *jenv, jclass jcls, jlongArray jindptr, jintArray jindices, jfloatArray jdata, jlongArray jout) { (JNIEnv *jenv, jclass jcls, jlongArray jindptr, jintArray jindices, jfloatArray jdata, jlongArray jout) {
DMatrixHandle result; DMatrixHandle result;
jlong* indptr = jenv->GetLongArrayElements(jindptr, 0); jlong* indptr = jenv->GetLongArrayElements(jindptr, 0);
jint* indices = jenv->GetIntArrayElements(jindices, 0); jint* indices = jenv->GetIntArrayElements(jindices, 0);
jfloat* data = jenv->GetFloatArrayElements(jdata, 0); jfloat* data = jenv->GetFloatArrayElements(jdata, 0);
bst_ulong nindptr = (bst_ulong)jenv->GetArrayLength(jindptr); bst_ulong nindptr = (bst_ulong)jenv->GetArrayLength(jindptr);
bst_ulong nelem = (bst_ulong)jenv->GetArrayLength(jdata); bst_ulong nelem = (bst_ulong)jenv->GetArrayLength(jdata);
int ret = (jint) XGDMatrixCreateFromCSR((unsigned long const *)indptr, (unsigned int const *)indices, (float const *)data, nindptr, nelem, &result); int ret = (jint) XGDMatrixCreateFromCSR((unsigned long const *)indptr, (unsigned int const *)indices, (float const *)data, nindptr, nelem, &result);
setHandle(jenv, jout, result); setHandle(jenv, jout, result);
//Release //Release
jenv->ReleaseLongArrayElements(jindptr, indptr, 0); jenv->ReleaseLongArrayElements(jindptr, indptr, 0);
jenv->ReleaseIntArrayElements(jindices, indices, 0); jenv->ReleaseIntArrayElements(jindices, indices, 0);
jenv->ReleaseFloatArrayElements(jdata, data, 0); jenv->ReleaseFloatArrayElements(jdata, data, 0);
return ret; return ret;
} }
/* /*
@ -71,21 +76,21 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromCSR
*/ */
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromCSC JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromCSC
(JNIEnv *jenv, jclass jcls, jlongArray jindptr, jintArray jindices, jfloatArray jdata, jlongArray jout) { (JNIEnv *jenv, jclass jcls, jlongArray jindptr, jintArray jindices, jfloatArray jdata, jlongArray jout) {
DMatrixHandle result; DMatrixHandle result;
jlong* indptr = jenv->GetLongArrayElements(jindptr, NULL); jlong* indptr = jenv->GetLongArrayElements(jindptr, NULL);
jint* indices = jenv->GetIntArrayElements(jindices, 0); jint* indices = jenv->GetIntArrayElements(jindices, 0);
jfloat* data = jenv->GetFloatArrayElements(jdata, NULL); jfloat* data = jenv->GetFloatArrayElements(jdata, NULL);
bst_ulong nindptr = (bst_ulong)jenv->GetArrayLength(jindptr); bst_ulong nindptr = (bst_ulong)jenv->GetArrayLength(jindptr);
bst_ulong nelem = (bst_ulong)jenv->GetArrayLength(jdata); bst_ulong nelem = (bst_ulong)jenv->GetArrayLength(jdata);
int ret = (jint) XGDMatrixCreateFromCSC((unsigned long const *)indptr, (unsigned int const *)indices, (float const *)data, nindptr, nelem, &result); int ret = (jint) XGDMatrixCreateFromCSC((unsigned long const *)indptr, (unsigned int const *)indices, (float const *)data, nindptr, nelem, &result);
setHandle(jenv, jout, result); setHandle(jenv, jout, result);
//release //release
jenv->ReleaseLongArrayElements(jindptr, indptr, 0); jenv->ReleaseLongArrayElements(jindptr, indptr, 0);
jenv->ReleaseIntArrayElements(jindices, indices, 0); jenv->ReleaseIntArrayElements(jindices, indices, 0);
jenv->ReleaseFloatArrayElements(jdata, data, 0); jenv->ReleaseFloatArrayElements(jdata, data, 0);
return ret; return ret;
} }
/* /*
@ -95,15 +100,15 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromCSC
*/ */
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromMat JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromMat
(JNIEnv *jenv, jclass jcls, jfloatArray jdata, jint jnrow, jint jncol, jfloat jmiss, jlongArray jout) { (JNIEnv *jenv, jclass jcls, jfloatArray jdata, jint jnrow, jint jncol, jfloat jmiss, jlongArray jout) {
DMatrixHandle result; DMatrixHandle result;
jfloat* data = jenv->GetFloatArrayElements(jdata, 0); jfloat* data = jenv->GetFloatArrayElements(jdata, 0);
bst_ulong nrow = (bst_ulong)jnrow; bst_ulong nrow = (bst_ulong)jnrow;
bst_ulong ncol = (bst_ulong)jncol; bst_ulong ncol = (bst_ulong)jncol;
int ret = (jint) XGDMatrixCreateFromMat((float const *)data, nrow, ncol, jmiss, &result); int ret = (jint) XGDMatrixCreateFromMat((float const *)data, nrow, ncol, jmiss, &result);
setHandle(jenv, jout, result); setHandle(jenv, jout, result);
//release //release
jenv->ReleaseFloatArrayElements(jdata, data, 0); jenv->ReleaseFloatArrayElements(jdata, data, 0);
return ret; return ret;
} }
/* /*
@ -113,18 +118,18 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixCreateFromMat
*/ */
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSliceDMatrix JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSliceDMatrix
(JNIEnv *jenv, jclass jcls, jlong jhandle, jintArray jindexset, jlongArray jout) { (JNIEnv *jenv, jclass jcls, jlong jhandle, jintArray jindexset, jlongArray jout) {
DMatrixHandle result; DMatrixHandle result;
DMatrixHandle handle = (DMatrixHandle) jhandle; DMatrixHandle handle = (DMatrixHandle) jhandle;
jint* indexset = jenv->GetIntArrayElements(jindexset, 0); jint* indexset = jenv->GetIntArrayElements(jindexset, 0);
bst_ulong len = (bst_ulong)jenv->GetArrayLength(jindexset); bst_ulong len = (bst_ulong)jenv->GetArrayLength(jindexset);
int ret = XGDMatrixSliceDMatrix(handle, (int const *)indexset, len, &result); int ret = XGDMatrixSliceDMatrix(handle, (int const *)indexset, len, &result);
setHandle(jenv, jout, result); setHandle(jenv, jout, result);
//release //release
jenv->ReleaseIntArrayElements(jindexset, indexset, 0); jenv->ReleaseIntArrayElements(jindexset, indexset, 0);
return ret; return ret;
} }
/* /*
@ -133,10 +138,10 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSliceDMatrix
* Signature: (J)V * Signature: (J)V
*/ */
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixFree JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixFree
(JNIEnv *jenv, jclass jcls, jlong jhandle) { (JNIEnv *jenv, jclass jcls, jlong jhandle) {
DMatrixHandle handle = (DMatrixHandle) jhandle; DMatrixHandle handle = (DMatrixHandle) jhandle;
int ret = XGDMatrixFree(handle); int ret = XGDMatrixFree(handle);
return ret; return ret;
} }
/* /*
@ -146,11 +151,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixFree
*/ */
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSaveBinary JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSaveBinary
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfname, jint jsilent) { (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfname, jint jsilent) {
DMatrixHandle handle = (DMatrixHandle) jhandle; DMatrixHandle handle = (DMatrixHandle) jhandle;
const char* fname = jenv->GetStringUTFChars(jfname, 0); const char* fname = jenv->GetStringUTFChars(jfname, 0);
int ret = XGDMatrixSaveBinary(handle, fname, jsilent); int ret = XGDMatrixSaveBinary(handle, fname, jsilent);
if (fname) jenv->ReleaseStringUTFChars(jfname, (const char *)fname); if (fname) jenv->ReleaseStringUTFChars(jfname, (const char *)fname);
return ret; return ret;
} }
/* /*
@ -160,16 +165,16 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSaveBinary
*/ */
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSetFloatInfo JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSetFloatInfo
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jfloatArray jarray) { (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jfloatArray jarray) {
DMatrixHandle handle = (DMatrixHandle) jhandle; DMatrixHandle handle = (DMatrixHandle) jhandle;
const char* field = jenv->GetStringUTFChars(jfield, 0); const char* field = jenv->GetStringUTFChars(jfield, 0);
jfloat* array = jenv->GetFloatArrayElements(jarray, NULL); jfloat* array = jenv->GetFloatArrayElements(jarray, NULL);
bst_ulong len = (bst_ulong)jenv->GetArrayLength(jarray); bst_ulong len = (bst_ulong)jenv->GetArrayLength(jarray);
int ret = XGDMatrixSetFloatInfo(handle, field, (float const *)array, len); int ret = XGDMatrixSetFloatInfo(handle, field, (float const *)array, len);
//release //release
if (field) jenv->ReleaseStringUTFChars(jfield, field); if (field) jenv->ReleaseStringUTFChars(jfield, field);
jenv->ReleaseFloatArrayElements(jarray, array, 0); jenv->ReleaseFloatArrayElements(jarray, array, 0);
return ret; return ret;
} }
/* /*
@ -179,16 +184,16 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSetFloatInfo
*/ */
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSetUIntInfo JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSetUIntInfo
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jintArray jarray) { (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jintArray jarray) {
DMatrixHandle handle = (DMatrixHandle) jhandle; DMatrixHandle handle = (DMatrixHandle) jhandle;
const char* field = jenv->GetStringUTFChars(jfield, 0); const char* field = jenv->GetStringUTFChars(jfield, 0);
jint* array = jenv->GetIntArrayElements(jarray, NULL); jint* array = jenv->GetIntArrayElements(jarray, NULL);
bst_ulong len = (bst_ulong)jenv->GetArrayLength(jarray); bst_ulong len = (bst_ulong)jenv->GetArrayLength(jarray);
int ret = XGDMatrixSetUIntInfo(handle, (char const *)field, (unsigned int const *)array, len); int ret = XGDMatrixSetUIntInfo(handle, (char const *)field, (unsigned int const *)array, len);
//release //release
if (field) jenv->ReleaseStringUTFChars(jfield, (const char *)field); if (field) jenv->ReleaseStringUTFChars(jfield, (const char *)field);
jenv->ReleaseIntArrayElements(jarray, array, 0); jenv->ReleaseIntArrayElements(jarray, array, 0);
return ret; return ret;
} }
/* /*
@ -198,13 +203,13 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSetUIntInfo
*/ */
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSetGroup JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSetGroup
(JNIEnv * jenv, jclass jcls, jlong jhandle, jintArray jarray) { (JNIEnv * jenv, jclass jcls, jlong jhandle, jintArray jarray) {
DMatrixHandle handle = (DMatrixHandle) jhandle; DMatrixHandle handle = (DMatrixHandle) jhandle;
jint* array = jenv->GetIntArrayElements(jarray, NULL); jint* array = jenv->GetIntArrayElements(jarray, NULL);
bst_ulong len = (bst_ulong)jenv->GetArrayLength(jarray); bst_ulong len = (bst_ulong)jenv->GetArrayLength(jarray);
int ret = XGDMatrixSetGroup(handle, (unsigned int const *)array, len); int ret = XGDMatrixSetGroup(handle, (unsigned int const *)array, len);
//release //release
jenv->ReleaseIntArrayElements(jarray, array, 0); jenv->ReleaseIntArrayElements(jarray, array, 0);
return ret; return ret;
} }
/* /*
@ -214,19 +219,19 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixSetGroup
*/ */
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixGetFloatInfo JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixGetFloatInfo
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jobjectArray jout) { (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jobjectArray jout) {
DMatrixHandle handle = (DMatrixHandle) jhandle; DMatrixHandle handle = (DMatrixHandle) jhandle;
const char* field = jenv->GetStringUTFChars(jfield, 0); const char* field = jenv->GetStringUTFChars(jfield, 0);
bst_ulong len; bst_ulong len;
float *result; float *result;
int ret = XGDMatrixGetFloatInfo(handle, field, &len, (const float**) &result); int ret = XGDMatrixGetFloatInfo(handle, field, &len, (const float**) &result);
if (field) jenv->ReleaseStringUTFChars(jfield, field); if (field) jenv->ReleaseStringUTFChars(jfield, field);
jsize jlen = (jsize) len; jsize jlen = (jsize) len;
jfloatArray jarray = jenv->NewFloatArray(jlen); jfloatArray jarray = jenv->NewFloatArray(jlen);
jenv->SetFloatArrayRegion(jarray, 0, jlen, (jfloat *) result); jenv->SetFloatArrayRegion(jarray, 0, jlen, (jfloat *) result);
jenv->SetObjectArrayElement(jout, 0, (jobject) jarray); jenv->SetObjectArrayElement(jout, 0, (jobject) jarray);
return ret; return ret;
} }
/* /*
@ -236,18 +241,18 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixGetFloatInfo
*/ */
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixGetUIntInfo JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixGetUIntInfo
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jobjectArray jout) { (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jobjectArray jout) {
DMatrixHandle handle = (DMatrixHandle) jhandle; DMatrixHandle handle = (DMatrixHandle) jhandle;
const char* field = jenv->GetStringUTFChars(jfield, 0); const char* field = jenv->GetStringUTFChars(jfield, 0);
bst_ulong len; bst_ulong len;
unsigned int *result; unsigned int *result;
int ret = (jint) XGDMatrixGetUIntInfo(handle, field, &len, (const unsigned int **) &result); int ret = (jint) XGDMatrixGetUIntInfo(handle, field, &len, (const unsigned int **) &result);
if (field) jenv->ReleaseStringUTFChars(jfield, field); if (field) jenv->ReleaseStringUTFChars(jfield, field);
jsize jlen = (jsize) len; jsize jlen = (jsize) len;
jintArray jarray = jenv->NewIntArray(jlen); jintArray jarray = jenv->NewIntArray(jlen);
jenv->SetIntArrayRegion(jarray, 0, jlen, (jint *) result); jenv->SetIntArrayRegion(jarray, 0, jlen, (jint *) result);
jenv->SetObjectArrayElement(jout, 0, jarray); jenv->SetObjectArrayElement(jout, 0, jarray);
return ret; return ret;
} }
/* /*
@ -257,11 +262,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixGetUIntInfo
*/ */
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixNumRow JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixNumRow
(JNIEnv *jenv, jclass jcls, jlong jhandle, jlongArray jout) { (JNIEnv *jenv, jclass jcls, jlong jhandle, jlongArray jout) {
DMatrixHandle handle = (DMatrixHandle) jhandle; DMatrixHandle handle = (DMatrixHandle) jhandle;
bst_ulong result[1]; bst_ulong result[1];
int ret = (jint) XGDMatrixNumRow(handle, result); int ret = (jint) XGDMatrixNumRow(handle, result);
jenv->SetLongArrayRegion(jout, 0, 1, (const jlong *) result); jenv->SetLongArrayRegion(jout, 0, 1, (const jlong *) result);
return ret; return ret;
} }
/* /*
@ -271,30 +276,29 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGDMatrixNumRow
*/ */
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterCreate JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterCreate
(JNIEnv *jenv, jclass jcls, jlongArray jhandles, jlongArray jout) { (JNIEnv *jenv, jclass jcls, jlongArray jhandles, jlongArray jout) {
DMatrixHandle* handles; DMatrixHandle* handles = NULL;
bst_ulong len = 0; bst_ulong len = 0;
jlong* cjhandles = 0; jlong* cjhandles = 0;
BoosterHandle result; BoosterHandle result;
if(jhandles) { if (jhandles) {
len = (bst_ulong)jenv->GetArrayLength(jhandles); len = (bst_ulong)jenv->GetArrayLength(jhandles);
handles = new DMatrixHandle[len]; handles = new DMatrixHandle[len];
//put handle from jhandles to chandles //put handle from jhandles to chandles
cjhandles = jenv->GetLongArrayElements(jhandles, 0); cjhandles = jenv->GetLongArrayElements(jhandles, 0);
for(bst_ulong i=0; i<len; i++) { for(bst_ulong i=0; i<len; i++) {
handles[i] = (DMatrixHandle) cjhandles[i]; handles[i] = (DMatrixHandle) cjhandles[i];
}
} }
}
int ret = XGBoosterCreate(handles, len, &result);
//release int ret = XGBoosterCreate(handles, len, &result);
if(jhandles) { //release
delete[] handles; if (jhandles) {
jenv->ReleaseLongArrayElements(jhandles, cjhandles, 0); delete[] handles;
} jenv->ReleaseLongArrayElements(jhandles, cjhandles, 0);
setHandle(jenv, jout, result); }
setHandle(jenv, jout, result);
return ret; return ret;
} }
/* /*
@ -316,14 +320,14 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterFree
*/ */
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterSetParam JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterSetParam
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jname, jstring jvalue) { (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jname, jstring jvalue) {
BoosterHandle handle = (BoosterHandle) jhandle; BoosterHandle handle = (BoosterHandle) jhandle;
const char* name = jenv->GetStringUTFChars(jname, 0); const char* name = jenv->GetStringUTFChars(jname, 0);
const char* value = jenv->GetStringUTFChars(jvalue, 0); const char* value = jenv->GetStringUTFChars(jvalue, 0);
int ret = XGBoosterSetParam(handle, name, value); int ret = XGBoosterSetParam(handle, name, value);
//release //release
if (name) jenv->ReleaseStringUTFChars(jname, name); if (name) jenv->ReleaseStringUTFChars(jname, name);
if (value) jenv->ReleaseStringUTFChars(jvalue, value); if (value) jenv->ReleaseStringUTFChars(jvalue, value);
return ret; return ret;
} }
/* /*
@ -333,9 +337,9 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterSetParam
*/ */
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterUpdateOneIter JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterUpdateOneIter
(JNIEnv *jenv, jclass jcls, jlong jhandle, jint jiter, jlong jdtrain) { (JNIEnv *jenv, jclass jcls, jlong jhandle, jint jiter, jlong jdtrain) {
BoosterHandle handle = (BoosterHandle) jhandle; BoosterHandle handle = (BoosterHandle) jhandle;
DMatrixHandle dtrain = (DMatrixHandle) jdtrain; DMatrixHandle dtrain = (DMatrixHandle) jdtrain;
return XGBoosterUpdateOneIter(handle, jiter, dtrain); return XGBoosterUpdateOneIter(handle, jiter, dtrain);
} }
/* /*
@ -345,16 +349,16 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterUpdateOneIter
*/ */
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterBoostOneIter JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterBoostOneIter
(JNIEnv *jenv, jclass jcls, jlong jhandle, jlong jdtrain, jfloatArray jgrad, jfloatArray jhess) { (JNIEnv *jenv, jclass jcls, jlong jhandle, jlong jdtrain, jfloatArray jgrad, jfloatArray jhess) {
BoosterHandle handle = (BoosterHandle) jhandle; BoosterHandle handle = (BoosterHandle) jhandle;
DMatrixHandle dtrain = (DMatrixHandle) jdtrain; DMatrixHandle dtrain = (DMatrixHandle) jdtrain;
jfloat* grad = jenv->GetFloatArrayElements(jgrad, 0); jfloat* grad = jenv->GetFloatArrayElements(jgrad, 0);
jfloat* hess = jenv->GetFloatArrayElements(jhess, 0); jfloat* hess = jenv->GetFloatArrayElements(jhess, 0);
bst_ulong len = (bst_ulong)jenv->GetArrayLength(jgrad); bst_ulong len = (bst_ulong)jenv->GetArrayLength(jgrad);
int ret = XGBoosterBoostOneIter(handle, dtrain, grad, hess, len); int ret = XGBoosterBoostOneIter(handle, dtrain, grad, hess, len);
//release //release
jenv->ReleaseFloatArrayElements(jgrad, grad, 0); jenv->ReleaseFloatArrayElements(jgrad, grad, 0);
jenv->ReleaseFloatArrayElements(jhess, hess, 0); jenv->ReleaseFloatArrayElements(jhess, hess, 0);
return ret; return ret;
} }
/* /*
@ -364,45 +368,45 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterBoostOneIter
*/ */
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterEvalOneIter JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterEvalOneIter
(JNIEnv *jenv, jclass jcls, jlong jhandle, jint jiter, jlongArray jdmats, jobjectArray jevnames, jobjectArray jout) { (JNIEnv *jenv, jclass jcls, jlong jhandle, jint jiter, jlongArray jdmats, jobjectArray jevnames, jobjectArray jout) {
BoosterHandle handle = (BoosterHandle) jhandle; BoosterHandle handle = (BoosterHandle) jhandle;
DMatrixHandle* dmats = 0; DMatrixHandle* dmats = 0;
char **evnames = 0; char **evnames = 0;
char *result = 0; char *result = 0;
bst_ulong len = (bst_ulong)jenv->GetArrayLength(jdmats); bst_ulong len = (bst_ulong)jenv->GetArrayLength(jdmats);
if(len > 0) { if(len > 0) {
dmats = new DMatrixHandle[len]; dmats = new DMatrixHandle[len];
evnames = new char*[len]; evnames = new char*[len];
} }
//put handle from jhandles to chandles //put handle from jhandles to chandles
jlong* cjdmats = jenv->GetLongArrayElements(jdmats, 0); jlong* cjdmats = jenv->GetLongArrayElements(jdmats, 0);
for(bst_ulong i=0; i<len; i++) {
dmats[i] = (DMatrixHandle) cjdmats[i];
}
//transfer jObjectArray to char**, user strcpy and release JNI char* inplace
for(bst_ulong i=0; i<len; i++) {
jstring jevname = (jstring)jenv->GetObjectArrayElement(jevnames, i);
const char* cevname = jenv->GetStringUTFChars(jevname, 0);
evnames[i] = new char[jenv->GetStringLength(jevname)];
strcpy(evnames[i], cevname);
jenv->ReleaseStringUTFChars(jevname, cevname);
}
int ret = XGBoosterEvalOneIter(handle, jiter, dmats, (char const *(*)) evnames, len, (const char **) &result);
if(len > 0) {
delete[] dmats;
//release string chars
for(bst_ulong i=0; i<len; i++) { for(bst_ulong i=0; i<len; i++) {
dmats[i] = (DMatrixHandle) cjdmats[i]; delete[] evnames[i];
} }
//transfer jObjectArray to char**, user strcpy and release JNI char* inplace delete[] evnames;
for(bst_ulong i=0; i<len; i++) { jenv->ReleaseLongArrayElements(jdmats, cjdmats, 0);
jstring jevname = (jstring)jenv->GetObjectArrayElement(jevnames, i); }
const char* cevname = jenv->GetStringUTFChars(jevname, 0);
evnames[i] = new char[jenv->GetStringLength(jevname)]; jstring jinfo = 0;
strcpy(evnames[i], cevname); if (result) jinfo = jenv->NewStringUTF((const char *) result);
jenv->ReleaseStringUTFChars(jevname, cevname); jenv->SetObjectArrayElement(jout, 0, jinfo);
}
return ret;
int ret = XGBoosterEvalOneIter(handle, jiter, dmats, (char const *(*)) evnames, len, (const char **) &result);
if(len > 0) {
delete[] dmats;
//release string chars
for(bst_ulong i=0; i<len; i++) {
delete[] evnames[i];
}
delete[] evnames;
jenv->ReleaseLongArrayElements(jdmats, cjdmats, 0);
}
jstring jinfo = 0;
if (result) jinfo = jenv->NewStringUTF((const char *) result);
jenv->SetObjectArrayElement(jout, 0, jinfo);
return ret;
} }
/* /*
@ -412,17 +416,17 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterEvalOneIter
*/ */
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterPredict JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterPredict
(JNIEnv *jenv, jclass jcls, jlong jhandle, jlong jdmat, jint joption_mask, jint jntree_limit, jobjectArray jout) { (JNIEnv *jenv, jclass jcls, jlong jhandle, jlong jdmat, jint joption_mask, jint jntree_limit, jobjectArray jout) {
BoosterHandle handle = (BoosterHandle) jhandle; BoosterHandle handle = (BoosterHandle) jhandle;
DMatrixHandle dmat = (DMatrixHandle) jdmat; DMatrixHandle dmat = (DMatrixHandle) jdmat;
bst_ulong len; bst_ulong len;
float *result; float *result;
int ret = XGBoosterPredict(handle, dmat, joption_mask, (unsigned int) jntree_limit, &len, (const float **) &result); int ret = XGBoosterPredict(handle, dmat, joption_mask, (unsigned int) jntree_limit, &len, (const float **) &result);
jsize jlen = (jsize) len; jsize jlen = (jsize) len;
jfloatArray jarray = jenv->NewFloatArray(jlen); jfloatArray jarray = jenv->NewFloatArray(jlen);
jenv->SetFloatArrayRegion(jarray, 0, jlen, (jfloat *) result); jenv->SetFloatArrayRegion(jarray, 0, jlen, (jfloat *) result);
jenv->SetObjectArrayElement(jout, 0, jarray); jenv->SetObjectArrayElement(jout, 0, jarray);
return ret; return ret;
} }
/* /*
@ -432,12 +436,12 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterPredict
*/ */
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterLoadModel JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterLoadModel
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfname) { (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfname) {
BoosterHandle handle = (BoosterHandle) jhandle; BoosterHandle handle = (BoosterHandle) jhandle;
const char* fname = jenv->GetStringUTFChars(jfname, 0); const char* fname = jenv->GetStringUTFChars(jfname, 0);
int ret = XGBoosterLoadModel(handle, fname); int ret = XGBoosterLoadModel(handle, fname);
if (fname) jenv->ReleaseStringUTFChars(jfname,fname); if (fname) jenv->ReleaseStringUTFChars(jfname,fname);
return ret; return ret;
} }
/* /*
@ -447,13 +451,13 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterLoadModel
*/ */
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterSaveModel JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterSaveModel
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfname) { (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfname) {
BoosterHandle handle = (BoosterHandle) jhandle; BoosterHandle handle = (BoosterHandle) jhandle;
const char* fname = jenv->GetStringUTFChars(jfname, 0); const char* fname = jenv->GetStringUTFChars(jfname, 0);
int ret = XGBoosterSaveModel(handle, fname); int ret = XGBoosterSaveModel(handle, fname);
if (fname) jenv->ReleaseStringUTFChars(jfname, fname); if (fname) jenv->ReleaseStringUTFChars(jfname, fname);
return ret; return ret;
} }
/* /*
@ -463,9 +467,9 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterSaveModel
*/ */
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterLoadModelFromBuffer JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterLoadModelFromBuffer
(JNIEnv *jenv, jclass jcls, jlong jhandle, jlong jbuf, jlong jlen) { (JNIEnv *jenv, jclass jcls, jlong jhandle, jlong jbuf, jlong jlen) {
BoosterHandle handle = (BoosterHandle) jhandle; BoosterHandle handle = (BoosterHandle) jhandle;
void *buf = (void*) jbuf; void *buf = (void*) jbuf;
return XGBoosterLoadModelFromBuffer(handle, (void const *)buf, (bst_ulong) jlen); return XGBoosterLoadModelFromBuffer(handle, (void const *)buf, (bst_ulong) jlen);
} }
/* /*
@ -475,16 +479,16 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterLoadModelFromB
*/ */
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterGetModelRaw JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterGetModelRaw
(JNIEnv * jenv, jclass jcls, jlong jhandle, jobjectArray jout) { (JNIEnv * jenv, jclass jcls, jlong jhandle, jobjectArray jout) {
BoosterHandle handle = (BoosterHandle) jhandle; BoosterHandle handle = (BoosterHandle) jhandle;
bst_ulong len = 0; bst_ulong len = 0;
char *result; char *result;
int ret = XGBoosterGetModelRaw(handle, &len, (const char **) &result); int ret = XGBoosterGetModelRaw(handle, &len, (const char **) &result);
if (result){ if (result) {
jstring jinfo = jenv->NewStringUTF((const char *) result); jstring jinfo = jenv->NewStringUTF((const char *) result);
jenv->SetObjectArrayElement(jout, 0, jinfo); jenv->SetObjectArrayElement(jout, 0, jinfo);
} }
return ret; return ret;
} }
/* /*
@ -494,20 +498,129 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterGetModelRaw
*/ */
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterDumpModel JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterDumpModel
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfmap, jint jwith_stats, jobjectArray jout) { (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfmap, jint jwith_stats, jobjectArray jout) {
BoosterHandle handle = (BoosterHandle) jhandle; BoosterHandle handle = (BoosterHandle) jhandle;
const char *fmap = jenv->GetStringUTFChars(jfmap, 0); const char *fmap = jenv->GetStringUTFChars(jfmap, 0);
bst_ulong len = 0; bst_ulong len = 0;
char **result; char **result;
int ret = XGBoosterDumpModel(handle, fmap, jwith_stats, &len, (const char ***) &result); int ret = XGBoosterDumpModel(handle, fmap, jwith_stats, &len, (const char ***) &result);
jsize jlen = (jsize) len; jsize jlen = (jsize) len;
jobjectArray jinfos = jenv->NewObjectArray(jlen, jenv->FindClass("java/lang/String"), jenv->NewStringUTF("")); jobjectArray jinfos = jenv->NewObjectArray(jlen, jenv->FindClass("java/lang/String"), jenv->NewStringUTF(""));
for(int i=0 ; i<jlen; i++) { for(int i=0 ; i<jlen; i++) {
jenv->SetObjectArrayElement(jinfos, i, jenv->NewStringUTF((const char*) result[i])); jenv->SetObjectArrayElement(jinfos, i, jenv->NewStringUTF((const char*) result[i]));
} }
jenv->SetObjectArrayElement(jout, 0, jinfos); jenv->SetObjectArrayElement(jout, 0, jinfos);
if (fmap) jenv->ReleaseStringUTFChars(jfmap, (const char *)fmap); if (fmap) jenv->ReleaseStringUTFChars(jfmap, (const char *)fmap);
return ret; return ret;
}
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGBoosterLoadRabitCheckpoint
* Signature: (J[I)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterLoadRabitCheckpoint
(JNIEnv *jenv , jclass jcls, jlong jhandle, jintArray jout) {
BoosterHandle handle = (BoosterHandle) jhandle;
int version;
int ret = XGBoosterLoadRabitCheckpoint(handle, &version);
jenv->SetIntArrayRegion(jout, 0, 1, &version);
return ret;
}
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGBoosterSaveRabitCheckpoint
* Signature: (J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterSaveRabitCheckpoint
(JNIEnv *jenv, jclass jcls, jlong jhandle) {
BoosterHandle handle = (BoosterHandle) jhandle;
return XGBoosterSaveRabitCheckpoint(handle);
}
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: RabitInit
* Signature: ([Ljava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitInit
(JNIEnv *jenv, jclass jcls, jobjectArray jargs) {
std::vector<std::string> args;
std::vector<char*> argv;
bst_ulong len = (bst_ulong)jenv->GetArrayLength(jargs);
for (bst_ulong i = 0; i < len; ++i) {
jstring arg = (jstring)jenv->GetObjectArrayElement(jargs, i);
std::string s(jenv->GetStringUTFChars(arg, 0),
jenv->GetStringLength(arg));
if (s.length() != 0) args.push_back(s);
}
for (size_t i = 0; i < args.size(); ++i) {
argv.push_back(&args[i][0]);
}
RabitInit(args.size(), args.size() == 0 ? NULL : &argv[0]);
return 0;
}
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: RabitFinalize
* Signature: ()I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitFinalize
(JNIEnv *jenv, jclass jcls) {
RabitFinalize();
return 0;
}
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: RabitTrackerPrint
* Signature: (Ljava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitTrackerPrint
(JNIEnv *jenv, jclass jcls, jstring jmsg) {
std::string str(jenv->GetStringUTFChars(jmsg, 0),
jenv->GetStringLength(jmsg));
RabitTrackerPrint(str.c_str());
return 0;
}
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: RabitGetRank
* Signature: ([I)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitGetRank
(JNIEnv *jenv, jclass jcls, jintArray jout) {
int rank = RabitGetRank();
jenv->SetIntArrayRegion(jout, 0, 1, &rank);
return 0;
}
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: RabitGetWorldSize
* Signature: ([I)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitGetWorldSize
(JNIEnv *jenv, jclass jcls, jintArray jout) {
int out = RabitGetWorldSize();
jenv->SetIntArrayRegion(jout, 0, 1, &out);
return 0;
}
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: RabitVersionNumber
* Signature: ([I)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitVersionNumber
(JNIEnv *jenv, jclass jcls, jintArray jout) {
int out = RabitVersionNumber();
jenv->SetIntArrayRegion(jout, 0, 1, &out);
return 0;
} }

View File

@ -215,6 +215,86 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterGetModelRaw
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterDumpModel JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterDumpModel
(JNIEnv *, jclass, jlong, jstring, jint, jobjectArray); (JNIEnv *, jclass, jlong, jstring, jint, jobjectArray);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGBoosterGetAttr
* Signature: (JLjava/lang/String;[Ljava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterGetAttr
(JNIEnv *, jclass, jlong, jstring, jobjectArray);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGBoosterSetAttr
* Signature: (JLjava/lang/String;Ljava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterSetAttr
(JNIEnv *, jclass, jlong, jstring, jstring);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGBoosterLoadRabitCheckpoint
* Signature: (J[I)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterLoadRabitCheckpoint
(JNIEnv *, jclass, jlong, jintArray);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: XGBoosterSaveRabitCheckpoint
* Signature: (J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_XGBoosterSaveRabitCheckpoint
(JNIEnv *, jclass, jlong);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: RabitInit
* Signature: ([Ljava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitInit
(JNIEnv *, jclass, jobjectArray);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: RabitFinalize
* Signature: ()I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitFinalize
(JNIEnv *, jclass);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: RabitTrackerPrint
* Signature: (Ljava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitTrackerPrint
(JNIEnv *, jclass, jstring);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: RabitGetRank
* Signature: ([I)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitGetRank
(JNIEnv *, jclass, jintArray);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: RabitGetWorldSize
* Signature: ([I)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitGetWorldSize
(JNIEnv *, jclass, jintArray);
/*
* Class: ml_dmlc_xgboost4j_XgboostJNI
* Method: RabitVersionNumber
* Signature: ([I)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_XgboostJNI_RabitVersionNumber
(JNIEnv *, jclass, jintArray);
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif