From 91fedd85b0c216c11ca6f36e2621bd381d49f7ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=BB=84=E5=AD=90=E8=BD=A9?= Date: Tue, 29 Dec 2015 01:08:19 -0800 Subject: [PATCH] modify jni code --- java/xgboost4j_wrapper.cpp | 71 +++++++++++++++++++------------------- 1 file changed, 35 insertions(+), 36 deletions(-) diff --git a/java/xgboost4j_wrapper.cpp b/java/xgboost4j_wrapper.cpp index 04fbf3eed..d8ba5fb9b 100644 --- a/java/xgboost4j_wrapper.cpp +++ b/java/xgboost4j_wrapper.cpp @@ -14,6 +14,7 @@ #include "../wrapper/xgboost_wrapper.h" #include "xgboost4j_wrapper.h" +#include //helper functions //set handle @@ -215,14 +216,14 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixGetFl (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jobjectArray jout) { DMatrixHandle handle = (DMatrixHandle) jhandle; const char* field = jenv->GetStringUTFChars(jfield, 0); - bst_ulong len[1]; - float *result[1]; - int ret = XGDMatrixGetFloatInfo(handle, field, len, (const float **) result); + bst_ulong len; + float *result; + int ret = XGDMatrixGetFloatInfo(handle, field, &len, (const float**) &result); if (field) jenv->ReleaseStringUTFChars(jfield, field); - jsize jlen = (jsize) *len; + jsize jlen = (jsize) len; jfloatArray jarray = jenv->NewFloatArray(jlen); - jenv->SetFloatArrayRegion(jarray, 0, jlen, (jfloat *) *result); + jenv->SetFloatArrayRegion(jarray, 0, jlen, (jfloat *) result); jenv->SetObjectArrayElement(jout, 0, (jobject) jarray); return ret; @@ -237,15 +238,14 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixGetUI (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jobjectArray jout) { DMatrixHandle handle = (DMatrixHandle) jhandle; const char* field = jenv->GetStringUTFChars(jfield, 0); - bst_ulong len[1]; - *len = 0; - unsigned int *result[1]; - int ret = (jint) XGDMatrixGetUIntInfo(handle, field, len, (const unsigned int **) result); + bst_ulong len; + unsigned int *result; + int ret = (jint) XGDMatrixGetUIntInfo(handle, field, &len, (const unsigned int **) &result); if (field) jenv->ReleaseStringUTFChars(jfield, field); - jsize jlen = (jsize)*len; + jsize jlen = (jsize) len; jintArray jarray = jenv->NewIntArray(jlen); - jenv->SetIntArrayRegion(jarray, 0, jlen, (jint *) *result); + jenv->SetIntArrayRegion(jarray, 0, jlen, (jint *) result); jenv->SetObjectArrayElement(jout, 0, jarray); return ret; } @@ -367,7 +367,7 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterEvalO BoosterHandle handle = (BoosterHandle) jhandle; DMatrixHandle* dmats = 0; char **evnames = 0; - char *result[1]; + char *result = 0; bst_ulong len = (bst_ulong)jenv->GetArrayLength(jdmats); if(len > 0) { dmats = new DMatrixHandle[len]; @@ -378,26 +378,28 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterEvalO for(bst_ulong i=0; iGetObjectArrayElement(jevnames, i); - evnames[i] = (char *)jenv->GetStringUTFChars(jevname, 0); + const char* cevname = jenv->GetStringUTFChars(jevname, 0); + evnames[i] = new char[jenv->GetStringLength(jevname)]; + strcpy(evnames[i], cevname); + jenv->ReleaseStringUTFChars(jevname, cevname); } - int ret = XGBoosterEvalOneIter(handle, jiter, dmats, (char const *(*))evnames, len, (const char **) result); + int ret = XGBoosterEvalOneIter(handle, jiter, dmats, (char const *(*)) evnames, len, (const char **) &result); if(len > 0) { delete[] dmats; //release string chars for(bst_ulong i=0; iGetObjectArrayElement(jevnames, i); - jenv->ReleaseStringUTFChars(jevname, (const char*)evnames[i]); + delete[] evnames[i]; } delete[] evnames; jenv->ReleaseLongArrayElements(jdmats, cjdmats, 0); } jstring jinfo = 0; - if (*result) jinfo = jenv->NewStringUTF((const char *) *result); + if (result) jinfo = jenv->NewStringUTF((const char *) result); jenv->SetObjectArrayElement(jout, 0, jinfo); return ret; @@ -412,14 +414,13 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterPredi (JNIEnv *jenv, jclass jcls, jlong jhandle, jlong jdmat, jint joption_mask, jint jntree_limit, jobjectArray jout) { BoosterHandle handle = (BoosterHandle) jhandle; DMatrixHandle dmat = (DMatrixHandle) jdmat; - bst_ulong len[1]; - *len = 0; - float *result[1]; - int ret = XGBoosterPredict(handle, dmat, joption_mask, (unsigned int) jntree_limit, len, (const float **) result); + bst_ulong len; + float *result; + int ret = XGBoosterPredict(handle, dmat, joption_mask, (unsigned int) jntree_limit, &len, (const float **) &result); - jsize jlen = (jsize) *len; + jsize jlen = (jsize) len; jfloatArray jarray = jenv->NewFloatArray(jlen); - jenv->SetFloatArrayRegion(jarray, 0, jlen, (jfloat *) *result); + jenv->SetFloatArrayRegion(jarray, 0, jlen, (jfloat *) result); jenv->SetObjectArrayElement(jout, 0, jarray); return ret; } @@ -475,13 +476,12 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterLoadM JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterGetModelRaw (JNIEnv * jenv, jclass jcls, jlong jhandle, jobjectArray jout) { BoosterHandle handle = (BoosterHandle) jhandle; - bst_ulong len[1]; - *len = 0; - char *result[1]; + bst_ulong len = 0; + char *result; - int ret = XGBoosterGetModelRaw(handle, len, (const char **) result); - if (*result){ - jstring jinfo = jenv->NewStringUTF((const char *) *result); + int ret = XGBoosterGetModelRaw(handle, &len, (const char **) &result); + if (result){ + jstring jinfo = jenv->NewStringUTF((const char *) result); jenv->SetObjectArrayElement(jout, 0, jinfo); } return ret; @@ -496,16 +496,15 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterDumpM (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfmap, jint jwith_stats, jobjectArray jout) { BoosterHandle handle = (BoosterHandle) jhandle; const char *fmap = jenv->GetStringUTFChars(jfmap, 0); - bst_ulong len[1]; - *len = 0; - char **result[1]; + bst_ulong len = 0; + char **result; - int ret = XGBoosterDumpModel(handle, fmap, jwith_stats, len, (const char ***) result); + int ret = XGBoosterDumpModel(handle, fmap, jwith_stats, &len, (const char ***) &result); - jsize jlen = (jsize)*len; + jsize jlen = (jsize) len; jobjectArray jinfos = jenv->NewObjectArray(jlen, jenv->FindClass("java/lang/String"), jenv->NewStringUTF("")); for(int i=0 ; iSetObjectArrayElement(jinfos, i, jenv->NewStringUTF((const char*) result[0][i])); + jenv->SetObjectArrayElement(jinfos, i, jenv->NewStringUTF((const char*) result[i])); } jenv->SetObjectArrayElement(jout, 0, jinfos);