modify jni code
This commit is contained in:
parent
4a456b2a75
commit
91fedd85b0
@ -14,6 +14,7 @@
|
|||||||
|
|
||||||
#include "../wrapper/xgboost_wrapper.h"
|
#include "../wrapper/xgboost_wrapper.h"
|
||||||
#include "xgboost4j_wrapper.h"
|
#include "xgboost4j_wrapper.h"
|
||||||
|
#include <cstring>
|
||||||
|
|
||||||
//helper functions
|
//helper functions
|
||||||
//set handle
|
//set handle
|
||||||
@ -215,14 +216,14 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixGetFl
|
|||||||
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jobjectArray jout) {
|
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jobjectArray jout) {
|
||||||
DMatrixHandle handle = (DMatrixHandle) jhandle;
|
DMatrixHandle handle = (DMatrixHandle) jhandle;
|
||||||
const char* field = jenv->GetStringUTFChars(jfield, 0);
|
const char* field = jenv->GetStringUTFChars(jfield, 0);
|
||||||
bst_ulong len[1];
|
bst_ulong len;
|
||||||
float *result[1];
|
float *result;
|
||||||
int ret = XGDMatrixGetFloatInfo(handle, field, len, (const float **) result);
|
int ret = XGDMatrixGetFloatInfo(handle, field, &len, (const float**) &result);
|
||||||
if (field) jenv->ReleaseStringUTFChars(jfield, field);
|
if (field) jenv->ReleaseStringUTFChars(jfield, field);
|
||||||
|
|
||||||
jsize jlen = (jsize) *len;
|
jsize jlen = (jsize) len;
|
||||||
jfloatArray jarray = jenv->NewFloatArray(jlen);
|
jfloatArray jarray = jenv->NewFloatArray(jlen);
|
||||||
jenv->SetFloatArrayRegion(jarray, 0, jlen, (jfloat *) *result);
|
jenv->SetFloatArrayRegion(jarray, 0, jlen, (jfloat *) result);
|
||||||
jenv->SetObjectArrayElement(jout, 0, (jobject) jarray);
|
jenv->SetObjectArrayElement(jout, 0, (jobject) jarray);
|
||||||
|
|
||||||
return ret;
|
return ret;
|
||||||
@ -237,15 +238,14 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGDMatrixGetUI
|
|||||||
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jobjectArray jout) {
|
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jobjectArray jout) {
|
||||||
DMatrixHandle handle = (DMatrixHandle) jhandle;
|
DMatrixHandle handle = (DMatrixHandle) jhandle;
|
||||||
const char* field = jenv->GetStringUTFChars(jfield, 0);
|
const char* field = jenv->GetStringUTFChars(jfield, 0);
|
||||||
bst_ulong len[1];
|
bst_ulong len;
|
||||||
*len = 0;
|
unsigned int *result;
|
||||||
unsigned int *result[1];
|
int ret = (jint) XGDMatrixGetUIntInfo(handle, field, &len, (const unsigned int **) &result);
|
||||||
int ret = (jint) XGDMatrixGetUIntInfo(handle, field, len, (const unsigned int **) result);
|
|
||||||
if (field) jenv->ReleaseStringUTFChars(jfield, field);
|
if (field) jenv->ReleaseStringUTFChars(jfield, field);
|
||||||
|
|
||||||
jsize jlen = (jsize)*len;
|
jsize jlen = (jsize) len;
|
||||||
jintArray jarray = jenv->NewIntArray(jlen);
|
jintArray jarray = jenv->NewIntArray(jlen);
|
||||||
jenv->SetIntArrayRegion(jarray, 0, jlen, (jint *) *result);
|
jenv->SetIntArrayRegion(jarray, 0, jlen, (jint *) result);
|
||||||
jenv->SetObjectArrayElement(jout, 0, jarray);
|
jenv->SetObjectArrayElement(jout, 0, jarray);
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
@ -367,7 +367,7 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterEvalO
|
|||||||
BoosterHandle handle = (BoosterHandle) jhandle;
|
BoosterHandle handle = (BoosterHandle) jhandle;
|
||||||
DMatrixHandle* dmats = 0;
|
DMatrixHandle* dmats = 0;
|
||||||
char **evnames = 0;
|
char **evnames = 0;
|
||||||
char *result[1];
|
char *result = 0;
|
||||||
bst_ulong len = (bst_ulong)jenv->GetArrayLength(jdmats);
|
bst_ulong len = (bst_ulong)jenv->GetArrayLength(jdmats);
|
||||||
if(len > 0) {
|
if(len > 0) {
|
||||||
dmats = new DMatrixHandle[len];
|
dmats = new DMatrixHandle[len];
|
||||||
@ -378,26 +378,28 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterEvalO
|
|||||||
for(bst_ulong i=0; i<len; i++) {
|
for(bst_ulong i=0; i<len; i++) {
|
||||||
dmats[i] = (DMatrixHandle) cjdmats[i];
|
dmats[i] = (DMatrixHandle) cjdmats[i];
|
||||||
}
|
}
|
||||||
//transfer jObjectArray to char**
|
//transfer jObjectArray to char**, user strcpy and release JNI char* inplace
|
||||||
for(bst_ulong i=0; i<len; i++) {
|
for(bst_ulong i=0; i<len; i++) {
|
||||||
jstring jevname = (jstring)jenv->GetObjectArrayElement(jevnames, i);
|
jstring jevname = (jstring)jenv->GetObjectArrayElement(jevnames, i);
|
||||||
evnames[i] = (char *)jenv->GetStringUTFChars(jevname, 0);
|
const char* cevname = jenv->GetStringUTFChars(jevname, 0);
|
||||||
|
evnames[i] = new char[jenv->GetStringLength(jevname)];
|
||||||
|
strcpy(evnames[i], cevname);
|
||||||
|
jenv->ReleaseStringUTFChars(jevname, cevname);
|
||||||
}
|
}
|
||||||
|
|
||||||
int ret = XGBoosterEvalOneIter(handle, jiter, dmats, (char const *(*))evnames, len, (const char **) result);
|
int ret = XGBoosterEvalOneIter(handle, jiter, dmats, (char const *(*)) evnames, len, (const char **) &result);
|
||||||
if(len > 0) {
|
if(len > 0) {
|
||||||
delete[] dmats;
|
delete[] dmats;
|
||||||
//release string chars
|
//release string chars
|
||||||
for(bst_ulong i=0; i<len; i++) {
|
for(bst_ulong i=0; i<len; i++) {
|
||||||
jstring jevname = (jstring)jenv->GetObjectArrayElement(jevnames, i);
|
delete[] evnames[i];
|
||||||
jenv->ReleaseStringUTFChars(jevname, (const char*)evnames[i]);
|
|
||||||
}
|
}
|
||||||
delete[] evnames;
|
delete[] evnames;
|
||||||
jenv->ReleaseLongArrayElements(jdmats, cjdmats, 0);
|
jenv->ReleaseLongArrayElements(jdmats, cjdmats, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
jstring jinfo = 0;
|
jstring jinfo = 0;
|
||||||
if (*result) jinfo = jenv->NewStringUTF((const char *) *result);
|
if (result) jinfo = jenv->NewStringUTF((const char *) result);
|
||||||
jenv->SetObjectArrayElement(jout, 0, jinfo);
|
jenv->SetObjectArrayElement(jout, 0, jinfo);
|
||||||
|
|
||||||
return ret;
|
return ret;
|
||||||
@ -412,14 +414,13 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterPredi
|
|||||||
(JNIEnv *jenv, jclass jcls, jlong jhandle, jlong jdmat, jint joption_mask, jint jntree_limit, jobjectArray jout) {
|
(JNIEnv *jenv, jclass jcls, jlong jhandle, jlong jdmat, jint joption_mask, jint jntree_limit, jobjectArray jout) {
|
||||||
BoosterHandle handle = (BoosterHandle) jhandle;
|
BoosterHandle handle = (BoosterHandle) jhandle;
|
||||||
DMatrixHandle dmat = (DMatrixHandle) jdmat;
|
DMatrixHandle dmat = (DMatrixHandle) jdmat;
|
||||||
bst_ulong len[1];
|
bst_ulong len;
|
||||||
*len = 0;
|
float *result;
|
||||||
float *result[1];
|
int ret = XGBoosterPredict(handle, dmat, joption_mask, (unsigned int) jntree_limit, &len, (const float **) &result);
|
||||||
int ret = XGBoosterPredict(handle, dmat, joption_mask, (unsigned int) jntree_limit, len, (const float **) result);
|
|
||||||
|
|
||||||
jsize jlen = (jsize) *len;
|
jsize jlen = (jsize) len;
|
||||||
jfloatArray jarray = jenv->NewFloatArray(jlen);
|
jfloatArray jarray = jenv->NewFloatArray(jlen);
|
||||||
jenv->SetFloatArrayRegion(jarray, 0, jlen, (jfloat *) *result);
|
jenv->SetFloatArrayRegion(jarray, 0, jlen, (jfloat *) result);
|
||||||
jenv->SetObjectArrayElement(jout, 0, jarray);
|
jenv->SetObjectArrayElement(jout, 0, jarray);
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
@ -475,13 +476,12 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterLoadM
|
|||||||
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterGetModelRaw
|
JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterGetModelRaw
|
||||||
(JNIEnv * jenv, jclass jcls, jlong jhandle, jobjectArray jout) {
|
(JNIEnv * jenv, jclass jcls, jlong jhandle, jobjectArray jout) {
|
||||||
BoosterHandle handle = (BoosterHandle) jhandle;
|
BoosterHandle handle = (BoosterHandle) jhandle;
|
||||||
bst_ulong len[1];
|
bst_ulong len = 0;
|
||||||
*len = 0;
|
char *result;
|
||||||
char *result[1];
|
|
||||||
|
|
||||||
int ret = XGBoosterGetModelRaw(handle, len, (const char **) result);
|
int ret = XGBoosterGetModelRaw(handle, &len, (const char **) &result);
|
||||||
if (*result){
|
if (result){
|
||||||
jstring jinfo = jenv->NewStringUTF((const char *) *result);
|
jstring jinfo = jenv->NewStringUTF((const char *) result);
|
||||||
jenv->SetObjectArrayElement(jout, 0, jinfo);
|
jenv->SetObjectArrayElement(jout, 0, jinfo);
|
||||||
}
|
}
|
||||||
return ret;
|
return ret;
|
||||||
@ -496,16 +496,15 @@ JNIEXPORT jint JNICALL Java_org_dmlc_xgboost4j_wrapper_XgboostJNI_XGBoosterDumpM
|
|||||||
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfmap, jint jwith_stats, jobjectArray jout) {
|
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfmap, jint jwith_stats, jobjectArray jout) {
|
||||||
BoosterHandle handle = (BoosterHandle) jhandle;
|
BoosterHandle handle = (BoosterHandle) jhandle;
|
||||||
const char *fmap = jenv->GetStringUTFChars(jfmap, 0);
|
const char *fmap = jenv->GetStringUTFChars(jfmap, 0);
|
||||||
bst_ulong len[1];
|
bst_ulong len = 0;
|
||||||
*len = 0;
|
char **result;
|
||||||
char **result[1];
|
|
||||||
|
|
||||||
int ret = XGBoosterDumpModel(handle, fmap, jwith_stats, len, (const char ***) result);
|
int ret = XGBoosterDumpModel(handle, fmap, jwith_stats, &len, (const char ***) &result);
|
||||||
|
|
||||||
jsize jlen = (jsize)*len;
|
jsize jlen = (jsize) len;
|
||||||
jobjectArray jinfos = jenv->NewObjectArray(jlen, jenv->FindClass("java/lang/String"), jenv->NewStringUTF(""));
|
jobjectArray jinfos = jenv->NewObjectArray(jlen, jenv->FindClass("java/lang/String"), jenv->NewStringUTF(""));
|
||||||
for(int i=0 ; i<jlen; i++) {
|
for(int i=0 ; i<jlen; i++) {
|
||||||
jenv->SetObjectArrayElement(jinfos, i, jenv->NewStringUTF((const char*) result[0][i]));
|
jenv->SetObjectArrayElement(jinfos, i, jenv->NewStringUTF((const char*) result[i]));
|
||||||
}
|
}
|
||||||
jenv->SetObjectArrayElement(jout, 0, jinfos);
|
jenv->SetObjectArrayElement(jout, 0, jinfos);
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user