small change for jni wrapper

This commit is contained in:
yanqingmen 2015-12-20 22:13:53 +08:00
parent 7d23ea7e9e
commit 4a456b2a75

View File

@ -15,6 +15,14 @@
#include "../wrapper/xgboost_wrapper.h"
#include "xgboost4j_wrapper.h"
//helper functions
//set handle
void setHandle(JNIEnv *jenv, jlongArray jhandle, void* handle) {
long out[1];
out[0] = (long) handle;
jenv->SetLongArrayRegion(jhandle, 0, 1, (const jlong*) out);
}
JNIEXPORT jstring JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBGetLastError
(JNIEnv *jenv, jclass jcls) {
jstring jresult = 0 ;
@ -26,12 +34,10 @@ JNIEXPORT jstring JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBGetLastE
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromFile
(JNIEnv *jenv, jclass jcls, jstring jfname, jint jsilent, jlongArray jout) {
DMatrixHandle result;
long out[1];
const char* fname = jenv->GetStringUTFChars(jfname, 0);
int ret = XGDMatrixCreateFromFile(fname, jsilent, &result);
out[0] = (long) result;
if (fname) jenv->ReleaseStringUTFChars(jfname, fname);
jenv->SetLongArrayRegion(jout, 0, 1, (const jlong *) out);
setHandle(jenv, jout, result);
return ret;
}
@ -43,15 +49,13 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreat
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromCSR
(JNIEnv *jenv, jclass jcls, jlongArray jindptr, jintArray jindices, jfloatArray jdata, jlongArray jout) {
DMatrixHandle result;
unsigned long out[1];
jlong* indptr = jenv->GetLongArrayElements(jindptr, 0);
jint* indices = jenv->GetIntArrayElements(jindices, 0);
jfloat* data = jenv->GetFloatArrayElements(jdata, 0);
bst_ulong nindptr = (bst_ulong)jenv->GetArrayLength(jindptr);
bst_ulong nelem = (bst_ulong)jenv->GetArrayLength(jdata);
int ret = (jint) XGDMatrixCreateFromCSR((unsigned long const *)indptr, (unsigned int const *)indices, (float const *)data, nindptr, nelem, &result);
out[0] = (long)result;
jenv->SetLongArrayRegion(jout, 0, 1, (const jlong *) out);
setHandle(jenv, jout, result);
//Release
jenv->ReleaseLongArrayElements(jindptr, indptr, 0);
jenv->ReleaseIntArrayElements(jindices, indices, 0);
@ -66,9 +70,7 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreat
*/
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromCSC
(JNIEnv *jenv, jclass jcls, jlongArray jindptr, jintArray jindices, jfloatArray jdata, jlongArray jout) {
DMatrixHandle result;
unsigned long out[1];
DMatrixHandle result;
jlong* indptr = jenv->GetLongArrayElements(jindptr, NULL);
jint* indices = jenv->GetIntArrayElements(jindices, 0);
jfloat* data = jenv->GetFloatArrayElements(jdata, NULL);
@ -76,8 +78,7 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreat
bst_ulong nelem = (bst_ulong)jenv->GetArrayLength(jdata);
int ret = (jint) XGDMatrixCreateFromCSC((unsigned long const *)indptr, (unsigned int const *)indices, (float const *)data, nindptr, nelem, &result);
out[0] = (long) result;
jenv->SetLongArrayRegion(jout, 0, 1, (const jlong *) out);
setHandle(jenv, jout, result);
//release
jenv->ReleaseLongArrayElements(jindptr, indptr, 0);
jenv->ReleaseIntArrayElements(jindices, indices, 0);
@ -94,13 +95,11 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreat
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreateFromMat
(JNIEnv *jenv, jclass jcls, jfloatArray jdata, jint jnrow, jint jncol, jfloat jmiss, jlongArray jout) {
DMatrixHandle result;
unsigned long out[1];
jfloat* data = jenv->GetFloatArrayElements(jdata, 0);
bst_ulong nrow = (bst_ulong)jnrow;
bst_ulong ncol = (bst_ulong)jncol;
int ret = (jint) XGDMatrixCreateFromMat((float const *)data, nrow, ncol, jmiss, &result);
out[0] = (long) result;
jenv->SetLongArrayRegion(jout, 0, 1, (const jlong *) out);
setHandle(jenv, jout, result);
//release
jenv->ReleaseFloatArrayElements(jdata, data, 0);
return ret;
@ -114,16 +113,13 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixCreat
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSliceDMatrix
(JNIEnv *jenv, jclass jcls, jlong jhandle, jintArray jindexset, jlongArray jout) {
DMatrixHandle result;
DMatrixHandle handle;
unsigned long out[1];
DMatrixHandle handle = (DMatrixHandle) jhandle;
jint* indexset = jenv->GetIntArrayElements(jindexset, 0);
handle = (DMatrixHandle) jhandle;
bst_ulong len = (bst_ulong)jenv->GetArrayLength(jindexset);
int ret = XGDMatrixSliceDMatrix(handle, (int const *)indexset, len, &result);
out[0] = (long)result;
jenv->SetLongArrayRegion(jout, 0, 1, (const jlong *) out);
setHandle(jenv, jout, result);
//release
jenv->ReleaseIntArrayElements(jindexset, indexset, 0);
@ -137,8 +133,7 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSlice
*/
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixFree
(JNIEnv *jenv, jclass jcls, jlong jhandle) {
DMatrixHandle handle;
handle = (DMatrixHandle) jhandle;
DMatrixHandle handle = (DMatrixHandle) jhandle;
int ret = XGDMatrixFree(handle);
return ret;
}
@ -150,8 +145,7 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixFree
*/
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSaveBinary
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfname, jint jsilent) {
DMatrixHandle handle;
handle = (DMatrixHandle) jhandle;
DMatrixHandle handle = (DMatrixHandle) jhandle;
const char* fname = jenv->GetStringUTFChars(jfname, 0);
int ret = XGDMatrixSaveBinary(handle, fname, jsilent);
if (fname) jenv->ReleaseStringUTFChars(jfname, (const char *)fname);
@ -165,11 +159,9 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSaveB
*/
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetFloatInfo
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jfloatArray jarray) {
DMatrixHandle handle;
handle = (DMatrixHandle) jhandle;
DMatrixHandle handle = (DMatrixHandle) jhandle;
const char* field = jenv->GetStringUTFChars(jfield, 0);
jfloat* array = jenv->GetFloatArrayElements(jarray, NULL);
bst_ulong len = (bst_ulong)jenv->GetArrayLength(jarray);
int ret = XGDMatrixSetFloatInfo(handle, field, (float const *)array, len);
@ -186,8 +178,7 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetFl
*/
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetUIntInfo
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jintArray jarray) {
DMatrixHandle handle;
handle = (DMatrixHandle) jhandle;
DMatrixHandle handle = (DMatrixHandle) jhandle;
const char* field = jenv->GetStringUTFChars(jfield, 0);
jint* array = jenv->GetIntArrayElements(jarray, NULL);
bst_ulong len = (bst_ulong)jenv->GetArrayLength(jarray);
@ -206,8 +197,7 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetUI
*/
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetGroup
(JNIEnv * jenv, jclass jcls, jlong jhandle, jintArray jarray) {
DMatrixHandle handle;
handle = (DMatrixHandle) jhandle;
DMatrixHandle handle = (DMatrixHandle) jhandle;
jint* array = jenv->GetIntArrayElements(jarray, NULL);
bst_ulong len = (bst_ulong)jenv->GetArrayLength(jarray);
int ret = XGDMatrixSetGroup(handle, (unsigned int const *)array, len);
@ -223,8 +213,7 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixSetGr
*/
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixGetFloatInfo
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jobjectArray jout) {
DMatrixHandle handle;
handle = (DMatrixHandle) jhandle;
DMatrixHandle handle = (DMatrixHandle) jhandle;
const char* field = jenv->GetStringUTFChars(jfield, 0);
bst_ulong len[1];
float *result[1];
@ -246,8 +235,7 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixGetFl
*/
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixGetUIntInfo
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jobjectArray jout) {
DMatrixHandle handle;
handle = (DMatrixHandle) jhandle;
DMatrixHandle handle = (DMatrixHandle) jhandle;
const char* field = jenv->GetStringUTFChars(jfield, 0);
bst_ulong len[1];
*len = 0;
@ -269,8 +257,7 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixGetUI
*/
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixNumRow
(JNIEnv *jenv, jclass jcls, jlong jhandle, jlongArray jout) {
DMatrixHandle handle;
handle = (DMatrixHandle) jhandle;
DMatrixHandle handle = (DMatrixHandle) jhandle;
bst_ulong result[1];
int ret = (jint) XGDMatrixNumRow(handle, result);
jenv->SetLongArrayRegion(jout, 0, 1, (const jlong *) result);
@ -287,7 +274,6 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterCreat
DMatrixHandle* handles;
bst_ulong len = 0;
jlong* cjhandles = 0;
unsigned long out[1];
BoosterHandle result;
if(jhandles) {
@ -306,8 +292,7 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterCreat
delete[] handles;
jenv->ReleaseLongArrayElements(jhandles, cjhandles, 0);
}
out[0] = (long) result;
jenv->SetLongArrayRegion(jout, 0, 1, (const jlong *) out);
setHandle(jenv, jout, result);
return ret;
}