diff --git a/jvm-packages/xgboost4j/src/native/xgboost4j.cpp b/jvm-packages/xgboost4j/src/native/xgboost4j.cpp index 67b6a0ee4..cfab645ed 100644 --- a/jvm-packages/xgboost4j/src/native/xgboost4j.cpp +++ b/jvm-packages/xgboost4j/src/native/xgboost4j.cpp @@ -21,9 +21,11 @@ #include #include +#include // for copy_n #include #include #include +#include // for unique_ptr #include #include #include @@ -61,6 +63,11 @@ jint JNI_OnLoad(JavaVM *vm, void *reserved) { return JNI_VERSION_1_6; } +namespace { +template +using Deleter = std::function; +} // anonymous namespace + XGB_EXTERN_C int XGBoost4jCallbackDataIterNext( DataIterHandle data_handle, XGBCallbackSetData* set_function, @@ -102,54 +109,70 @@ XGB_EXTERN_C int XGBoost4jCallbackDataIterNext( batch, jenv->GetFieldID(batchClass, "featureValue", "[F")); jint jcols = jenv->GetIntField( batch, jenv->GetFieldID(batchClass, "featureCols", "I")); - XGBoostBatchCSR cbatch; - cbatch.size = jenv->GetArrayLength(joffset) - 1; - cbatch.columns = jcols; - cbatch.offset = reinterpret_cast( - jenv->GetLongArrayElements(joffset, 0)); - if (jlabel != nullptr) { - cbatch.label = jenv->GetFloatArrayElements(jlabel, 0); - CHECK_EQ(jenv->GetArrayLength(jlabel), static_cast(cbatch.size)) - << "batch.label.length must equal batch.numRows()"; - } else { - cbatch.label = nullptr; - } - if (jweight != nullptr) { - cbatch.weight = jenv->GetFloatArrayElements(jweight, 0); - CHECK_EQ(jenv->GetArrayLength(jweight), static_cast(cbatch.size)) - << "batch.weight.length must equal batch.numRows()"; - } else { - cbatch.weight = nullptr; - } - long max_elem = cbatch.offset[cbatch.size]; - cbatch.index = (int*) jenv->GetIntArrayElements(jindex, 0); - cbatch.value = jenv->GetFloatArrayElements(jvalue, 0); - CHECK_EQ(jenv->GetArrayLength(jindex), max_elem) - << "batch.index.length must equal batch.offset.back()"; - CHECK_EQ(jenv->GetArrayLength(jvalue), max_elem) - << "batch.index.length must equal batch.offset.back()"; - // cbatch is ready - CHECK_EQ((*set_function)(set_function_handle, cbatch), 0) - << XGBGetLastError(); - // release the elements. - jenv->ReleaseLongArrayElements( - joffset, reinterpret_cast(cbatch.offset), 0); - jenv->DeleteLocalRef(joffset); - if (jlabel != nullptr) { - jenv->ReleaseFloatArrayElements(jlabel, cbatch.label, 0); - jenv->DeleteLocalRef(jlabel); - } - if (jweight != nullptr) { - jenv->ReleaseFloatArrayElements(jweight, cbatch.weight, 0); - jenv->DeleteLocalRef(jweight); - } - jenv->ReleaseIntArrayElements(jindex, (jint*) cbatch.index, 0); - jenv->DeleteLocalRef(jindex); - jenv->ReleaseFloatArrayElements(jvalue, cbatch.value, 0); - jenv->DeleteLocalRef(jvalue); + std::unique_ptr> cbatch{ + [&] { + auto ptr = new XGBoostBatchCSR; + auto &cbatch = *ptr; + + // Init + cbatch.size = jenv->GetArrayLength(joffset) - 1; + cbatch.columns = jcols; + cbatch.offset = reinterpret_cast(jenv->GetLongArrayElements(joffset, nullptr)); + + if (jlabel != nullptr) { + cbatch.label = jenv->GetFloatArrayElements(jlabel, nullptr); + CHECK_EQ(jenv->GetArrayLength(jlabel), static_cast(cbatch.size)) + << "batch.label.length must equal batch.numRows()"; + } else { + cbatch.label = nullptr; + } + + if (jweight != nullptr) { + cbatch.weight = jenv->GetFloatArrayElements(jweight, nullptr); + CHECK_EQ(jenv->GetArrayLength(jweight), static_cast(cbatch.size)) + << "batch.weight.length must equal batch.numRows()"; + } else { + cbatch.weight = nullptr; + } + + auto max_elem = cbatch.offset[cbatch.size]; + cbatch.index = (int *)jenv->GetIntArrayElements(jindex, nullptr); + cbatch.value = jenv->GetFloatArrayElements(jvalue, nullptr); + CHECK_EQ(jenv->GetArrayLength(jindex), max_elem) + << "batch.index.length must equal batch.offset.back()"; + CHECK_EQ(jenv->GetArrayLength(jvalue), max_elem) + << "batch.index.length must equal batch.offset.back()"; + return ptr; + }(), + [&](XGBoostBatchCSR *ptr) { + auto &cbatch = *ptr; + jenv->ReleaseLongArrayElements(joffset, reinterpret_cast(cbatch.offset), 0); + jenv->DeleteLocalRef(joffset); + + if (jlabel) { + jenv->ReleaseFloatArrayElements(jlabel, cbatch.label, 0); + jenv->DeleteLocalRef(jlabel); + } + if (jweight) { + jenv->ReleaseFloatArrayElements(jweight, cbatch.weight, 0); + jenv->DeleteLocalRef(jweight); + } + + jenv->ReleaseIntArrayElements(jindex, (jint *)cbatch.index, 0); + jenv->DeleteLocalRef(jindex); + + jenv->ReleaseFloatArrayElements(jvalue, cbatch.value, 0); + jenv->DeleteLocalRef(jvalue); + + delete ptr; + }}; + + CHECK_EQ((*set_function)(set_function_handle, *cbatch), 0) << XGBGetLastError(); + jenv->DeleteLocalRef(batch); jenv->DeleteLocalRef(batchClass); + ret_value = 1; } else { ret_value = 0; @@ -179,7 +202,7 @@ JNIEXPORT jstring JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBGetLastError (JNIEnv *jenv, jclass jcls) { jstring jresult = 0; const char* result = XGBGetLastError(); - if (result != NULL) { + if (result) { jresult = jenv->NewStringUTF(result); } return jresult; @@ -193,16 +216,15 @@ JNIEXPORT jstring JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBGetLastError JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromDataIter (JNIEnv *jenv, jclass jcls, jobject jiter, jstring jcache_info, jlongArray jout) { DMatrixHandle result; - const char* cache_info = nullptr; + std::unique_ptr> cache_info; if (jcache_info != nullptr) { - cache_info = jenv->GetStringUTFChars(jcache_info, 0); + cache_info = {jenv->GetStringUTFChars(jcache_info, nullptr), [&](char const *ptr) { + jenv->ReleaseStringUTFChars(jcache_info, ptr); + }}; } - int ret = XGDMatrixCreateFromDataIter( - jiter, XGBoost4jCallbackDataIterNext, cache_info, &result); + int ret = + XGDMatrixCreateFromDataIter(jiter, XGBoost4jCallbackDataIterNext, cache_info.get(), &result); JVM_CHECK_CALL(ret); - if (cache_info) { - jenv->ReleaseStringUTFChars(jcache_info, cache_info); - } setHandle(jenv, jout, result); return ret; } @@ -212,20 +234,22 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFro * Method: XGDMatrixCreateFromFile * Signature: (Ljava/lang/String;I[J)I */ -JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromFile - (JNIEnv *jenv, jclass jcls, jstring jfname, jint jsilent, jlongArray jout) { +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromFile( + JNIEnv *jenv, jclass jcls, jstring jfname, jint jsilent, jlongArray jout) { + std::unique_ptr> fname{jenv->GetStringUTFChars(jfname, nullptr), + [&](char const *ptr) { + jenv->ReleaseStringUTFChars(jfname, ptr); + }}; DMatrixHandle result; - const char* fname = jenv->GetStringUTFChars(jfname, 0); - int ret = XGDMatrixCreateFromFile(fname, jsilent, &result); + int ret = XGDMatrixCreateFromFile(fname.get(), jsilent, &result); JVM_CHECK_CALL(ret); - if (fname) { - jenv->ReleaseStringUTFChars(jfname, fname); - } setHandle(jenv, jout, result); return ret; } namespace { +using JavaIndT = + std::conditional_t::value, std::int32_t, long>; /** * \brief Create from sparse matrix. * @@ -238,20 +262,28 @@ jint MakeJVMSparseInput(JNIEnv *jenv, jlongArray jindptr, jintArray jindices, jf jfloat jmissing, jint jnthread, Fn &&maker, jlongArray jout) { DMatrixHandle result; - jlong *indptr = jenv->GetLongArrayElements(jindptr, nullptr); - jint *indices = jenv->GetIntArrayElements(jindices, nullptr); - jfloat *data = jenv->GetFloatArrayElements(jdata, nullptr); + std::unique_ptr> indptr{jenv->GetLongArrayElements(jindptr, nullptr), + [&](jlong *ptr) { + jenv->ReleaseLongArrayElements(jindptr, ptr, 0); + }}; + std::unique_ptr> indices{jenv->GetIntArrayElements(jindices, nullptr), + [&](jint *ptr) { + jenv->ReleaseIntArrayElements(jindices, ptr, 0); + }}; + std::unique_ptr> data{jenv->GetFloatArrayElements(jdata, nullptr), + [&](jfloat *ptr) { + jenv->ReleaseFloatArrayElements(jdata, ptr, 0); + }}; + bst_ulong nindptr = static_cast(jenv->GetArrayLength(jindptr)); bst_ulong nelem = static_cast(jenv->GetArrayLength(jdata)); std::string sindptr, sindices, sdata; - CHECK_EQ(indptr[nindptr - 1], nelem); + CHECK_EQ(indptr.get()[nindptr - 1], nelem); using IndPtrT = std::conditional_t::value, long, long long>; - using IndT = - std::conditional_t::value, std::int32_t, long>; xgboost::detail::MakeSparseFromPtr( - static_cast(indptr), static_cast(indices), - static_cast(data), nindptr, &sindptr, &sindices, &sdata); + static_cast(indptr.get()), static_cast(indices.get()), + static_cast(data.get()), nindptr, &sindptr, &sindices, &sdata); xgboost::Json jconfig{xgboost::Object{}}; auto missing = static_cast(jmissing); @@ -265,11 +297,6 @@ jint MakeJVMSparseInput(JNIEnv *jenv, jlongArray jindptr, jintArray jindices, jf jint ret = maker(sindptr.c_str(), sindices.c_str(), sdata.c_str(), config.c_str(), &result); JVM_CHECK_CALL(ret); setHandle(jenv, jout, result); - - // Release - jenv->ReleaseLongArrayElements(jindptr, indptr, 0); - jenv->ReleaseIntArrayElements(jindices, indices, 0); - jenv->ReleaseFloatArrayElements(jdata, data, 0); return ret; } } // anonymous namespace @@ -335,37 +362,55 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFro JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromMat (JNIEnv *jenv, jclass jcls, jfloatArray jdata, jint jnrow, jint jncol, jfloat jmiss, jlongArray jout) { DMatrixHandle result; - jfloat* data = jenv->GetFloatArrayElements(jdata, 0); + std::unique_ptr> data{jenv->GetFloatArrayElements(jdata, 0), [&](jfloat* ptr) { + jenv->ReleaseFloatArrayElements(jdata, ptr, 0); + }}; + bst_ulong nrow = (bst_ulong)jnrow; bst_ulong ncol = (bst_ulong)jncol; - jint ret = (jint) XGDMatrixCreateFromMat((float const *)data, nrow, ncol, jmiss, &result); + jint ret = + XGDMatrixCreateFromMat(static_cast(data.get()), nrow, ncol, jmiss, &result); JVM_CHECK_CALL(ret); setHandle(jenv, jout, result); - //release - jenv->ReleaseFloatArrayElements(jdata, data, 0); return ret; } +namespace { +// Workaround int is not the same as jint. For some reason, if constexpr couldn't dispatch +// the following. +template +auto SliaceDMatrixWinWar(DMatrixHandle handle, T *ptr, std::size_t len, DMatrixHandle *result) { + // default to not allowing slicing with group ID specified -- feel free to add if necessary + return XGDMatrixSliceDMatrixEx(handle, ptr, len, result, 0); +} + +template <> +auto SliaceDMatrixWinWar(DMatrixHandle handle, long *ptr, std::size_t len, DMatrixHandle *result) { + std::vector copy(len); + std::copy_n(ptr, len, copy.begin()); + // default to not allowing slicing with group ID specified -- feel free to add if necessary + return XGDMatrixSliceDMatrixEx(handle, copy.data(), len, result, 0); +} +} // namespace + /* * Class: ml_dmlc_xgboost4j_java_XGBoostJNI * Method: XGDMatrixSliceDMatrix * Signature: (J[I)J */ -JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSliceDMatrix - (JNIEnv *jenv, jclass jcls, jlong jhandle, jintArray jindexset, jlongArray jout) { +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSliceDMatrix( + JNIEnv *jenv, jclass jcls, jlong jhandle, jintArray jindexset, jlongArray jout) { DMatrixHandle result; - DMatrixHandle handle = (DMatrixHandle) jhandle; + auto handle = reinterpret_cast(jhandle); - jint* indexset = jenv->GetIntArrayElements(jindexset, 0); - bst_ulong len = (bst_ulong)jenv->GetArrayLength(jindexset); - - // default to not allowing slicing with group ID specified -- feel free to add if necessary - jint ret = (jint) XGDMatrixSliceDMatrixEx(handle, (int const *)indexset, len, &result, 0); + std::unique_ptr> indexset{jenv->GetIntArrayElements(jindexset, nullptr), + [&](jint *ptr) { + jenv->ReleaseIntArrayElements(jindexset, ptr, 0); + }}; + auto len = static_cast(jenv->GetArrayLength(jindexset)); + auto ret = SliaceDMatrixWinWar(handle, indexset.get(), len, &result); JVM_CHECK_CALL(ret); setHandle(jenv, jout, result); - //release - jenv->ReleaseIntArrayElements(jindexset, indexset, 0); - return ret; } @@ -386,13 +431,17 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixFree * Method: XGDMatrixSaveBinary * Signature: (JLjava/lang/String;I)V */ -JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSaveBinary - (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfname, jint jsilent) { - DMatrixHandle handle = (DMatrixHandle) jhandle; - const char* fname = jenv->GetStringUTFChars(jfname, 0); - int ret = XGDMatrixSaveBinary(handle, fname, jsilent); +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSaveBinary( + JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfname, jint jsilent) { + DMatrixHandle handle = reinterpret_cast(jhandle); + std::unique_ptr> fname{ + jenv->GetStringUTFChars(jfname, nullptr), [&](char const *ptr) { + if (ptr) { + jenv->ReleaseStringUTFChars(jfname, ptr); + } + }}; + int ret = XGDMatrixSaveBinary(handle, fname.get(), jsilent); JVM_CHECK_CALL(ret); - if (fname) jenv->ReleaseStringUTFChars(jfname, (const char *)fname); return ret; } @@ -401,20 +450,23 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSaveBinar * Method: XGDMatrixSetFloatInfo * Signature: (JLjava/lang/String;[F)V */ -JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetFloatInfo - (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jfloatArray jarray) { - DMatrixHandle handle = (DMatrixHandle) jhandle; - const char* field = jenv->GetStringUTFChars(jfield, 0); +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetFloatInfo( + JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jfloatArray jarray) { + auto handle = reinterpret_cast(jhandle); + std::unique_ptr> field{ + jenv->GetStringUTFChars(jfield, nullptr), [&](char const *ptr) { + if (ptr) { + jenv->ReleaseStringUTFChars(jfield, ptr); + } + }}; + std::unique_ptr> array{jenv->GetFloatArrayElements(jarray, nullptr), + [&](jfloat *ptr) { + jenv->ReleaseFloatArrayElements(jarray, ptr, 0); + }}; - jfloat* array = jenv->GetFloatArrayElements(jarray, NULL); bst_ulong len = (bst_ulong)jenv->GetArrayLength(jarray); - auto str = xgboost::linalg::Make1dInterface(array, len); - int ret = XGDMatrixSetInfoFromInterface(handle, field, str.c_str()); - JVM_CHECK_CALL(ret); - //release - if (field) jenv->ReleaseStringUTFChars(jfield, field); - jenv->ReleaseFloatArrayElements(jarray, array, 0); - return ret; + auto str = xgboost::linalg::Make1dInterface(array.get(), len); + return XGDMatrixSetInfoFromInterface(handle, field.get(), str.c_str()); } /* @@ -424,18 +476,20 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetFloatI */ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetUIntInfo (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jintArray jarray) { - DMatrixHandle handle = (DMatrixHandle) jhandle; - const char* field = jenv->GetStringUTFChars(jfield, 0); - jint* array = jenv->GetIntArrayElements(jarray, NULL); + auto handle = reinterpret_cast(jhandle); + std::unique_ptr> field{ + jenv->GetStringUTFChars(jfield, nullptr), [&](char const *ptr) { + if (ptr) { + jenv->ReleaseStringUTFChars(jfield, ptr); + } + }}; + std::unique_ptr> array{jenv->GetIntArrayElements(jarray, nullptr), + [&](jint *ptr) { + jenv->ReleaseIntArrayElements(jarray, ptr, 0); + }}; bst_ulong len = (bst_ulong)jenv->GetArrayLength(jarray); - auto str = xgboost::linalg::Make1dInterface(array, len); - int ret = XGDMatrixSetInfoFromInterface(handle, field, str.c_str()); - JVM_CHECK_CALL(ret); - //release - if (field) jenv->ReleaseStringUTFChars(jfield, (const char *)field); - jenv->ReleaseIntArrayElements(jarray, array, 0); - - return ret; + auto str = xgboost::linalg::Make1dInterface(array.get(), len); + return XGDMatrixSetInfoFromInterface(handle, field.get(), str.c_str()); } /* @@ -445,13 +499,17 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetUIntIn */ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixGetFloatInfo (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jobjectArray jout) { - DMatrixHandle handle = (DMatrixHandle) jhandle; - const char* field = jenv->GetStringUTFChars(jfield, 0); + auto handle = reinterpret_cast(jhandle); + std::unique_ptr> field{ + jenv->GetStringUTFChars(jfield, nullptr), [&](char const *ptr) { + if (ptr) { + jenv->ReleaseStringUTFChars(jfield, ptr); + } + }}; bst_ulong len; float *result; - int ret = XGDMatrixGetFloatInfo(handle, field, &len, (const float**) &result); + int ret = XGDMatrixGetFloatInfo(handle, field.get(), &len, (const float**) &result); JVM_CHECK_CALL(ret); - if (field) jenv->ReleaseStringUTFChars(jfield, field); jsize jlen = (jsize) len; jfloatArray jarray = jenv->NewFloatArray(jlen); @@ -468,13 +526,17 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixGetFloatI */ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixGetUIntInfo (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jobjectArray jout) { - DMatrixHandle handle = (DMatrixHandle) jhandle; - const char* field = jenv->GetStringUTFChars(jfield, 0); + auto handle = reinterpret_cast(jhandle); + std::unique_ptr> field{ + jenv->GetStringUTFChars(jfield, nullptr), [&](char const *ptr) { + if (ptr) { + jenv->ReleaseStringUTFChars(jfield, ptr); + } + }}; bst_ulong len; unsigned int *result; - int ret = (jint) XGDMatrixGetUIntInfo(handle, field, &len, (const unsigned int **) &result); + int ret = (jint)XGDMatrixGetUIntInfo(handle, field.get(), &len, (const unsigned int **)&result); JVM_CHECK_CALL(ret); - if (field) jenv->ReleaseStringUTFChars(jfield, field); jsize jlen = (jsize) len; jintArray jarray = jenv->NewIntArray(jlen); @@ -490,7 +552,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixGetUIntIn */ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixNumRow (JNIEnv *jenv, jclass jcls, jlong jhandle, jlongArray jout) { - DMatrixHandle handle = (DMatrixHandle) jhandle; + auto handle = reinterpret_cast(jhandle); bst_ulong result[1]; int ret = (jint) XGDMatrixNumRow(handle, result); JVM_CHECK_CALL(ret); @@ -525,11 +587,13 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterCreate std::vector handles; if (jhandles != nullptr) { size_t len = jenv->GetArrayLength(jhandles); - jlong *cjhandles = jenv->GetLongArrayElements(jhandles, 0); + std::unique_ptr> cjhandles{ + jenv->GetLongArrayElements(jhandles, nullptr), [&](jlong *ptr) { + jenv->ReleaseLongArrayElements(jhandles, ptr, 0); + }}; for (size_t i = 0; i < len; ++i) { - handles.push_back((DMatrixHandle) cjhandles[i]); + handles.push_back(reinterpret_cast(cjhandles.get()[i])); } - jenv->ReleaseLongArrayElements(jhandles, cjhandles, 0); } BoosterHandle result; int ret = XGBoosterCreate(dmlc::BeginPtr(handles), handles.size(), &result); @@ -543,28 +607,35 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterCreate * Method: XGBoosterFree * Signature: (J)V */ -JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterFree - (JNIEnv *jenv, jclass jcls, jlong jhandle) { - BoosterHandle handle = (BoosterHandle) jhandle; - return XGBoosterFree(handle); +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterFree(JNIEnv *jenv, + jclass jcls, + jlong jhandle) { + auto handle = reinterpret_cast(jhandle); + return XGBoosterFree(handle); } - /* * Class: ml_dmlc_xgboost4j_java_XGBoostJNI * Method: XGBoosterSetParam * Signature: (JLjava/lang/String;Ljava/lang/String;)V */ -JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSetParam - (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jname, jstring jvalue) { - BoosterHandle handle = (BoosterHandle) jhandle; - const char* name = jenv->GetStringUTFChars(jname, 0); - const char* value = jenv->GetStringUTFChars(jvalue, 0); - int ret = XGBoosterSetParam(handle, name, value); +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSetParam( + JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jname, jstring jvalue) { + auto handle = reinterpret_cast(jhandle); + std::unique_ptr> name{jenv->GetStringUTFChars(jname, nullptr), + [&](char const *ptr) { + if (ptr) { + jenv->ReleaseStringUTFChars(jname, ptr); + } + }}; + std::unique_ptr> value{ + jenv->GetStringUTFChars(jvalue, nullptr), [&](char const *ptr) { + if (ptr) { + jenv->ReleaseStringUTFChars(jvalue, ptr); + } + }}; + int ret = XGBoosterSetParam(handle, name.get(), value.get()); JVM_CHECK_CALL(ret); - //release - if (name) jenv->ReleaseStringUTFChars(jname, name); - if (value) jenv->ReleaseStringUTFChars(jvalue, value); return ret; } @@ -575,8 +646,8 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSetParam */ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterUpdateOneIter (JNIEnv *jenv, jclass jcls, jlong jhandle, jint jiter, jlong jdtrain) { - BoosterHandle handle = (BoosterHandle) jhandle; - DMatrixHandle dtrain = (DMatrixHandle) jdtrain; + auto handle = reinterpret_cast(jhandle); + auto dtrain = reinterpret_cast(jdtrain); return XGBoosterUpdateOneIter(handle, jiter, dtrain); } @@ -589,16 +660,22 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterTrainOneI JNIEnv *jenv, jclass jcls, jlong jhandle, jlong jdtrain, jint jiter, jfloatArray jgrad, jfloatArray jhess) { API_BEGIN(); - BoosterHandle handle = reinterpret_cast(jhandle); - DMatrixHandle dtrain = reinterpret_cast(jdtrain); + auto handle = reinterpret_cast(jhandle); + auto dtrain = reinterpret_cast(jdtrain); CHECK(handle); CHECK(dtrain); bst_ulong n_samples{0}; JVM_CHECK_CALL(XGDMatrixNumRow(dtrain, &n_samples)); bst_ulong len = static_cast(jenv->GetArrayLength(jgrad)); - jfloat *grad = jenv->GetFloatArrayElements(jgrad, nullptr); - jfloat *hess = jenv->GetFloatArrayElements(jhess, nullptr); + std::unique_ptr> grad{jenv->GetFloatArrayElements(jgrad, nullptr), + [&](jfloat *ptr) { + jenv->ReleaseFloatArrayElements(jgrad, ptr, 0); + }}; + std::unique_ptr> hess{jenv->GetFloatArrayElements(jhess, nullptr), + [&](jfloat *ptr) { + jenv->ReleaseFloatArrayElements(jhess, ptr, 0); + }}; CHECK(grad); CHECK(hess); @@ -610,15 +687,9 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterTrainOneI auto ctx = xgboost::detail::BoosterCtx(handle); auto [s_grad, s_hess] = xgboost::detail::MakeGradientInterface( - ctx, grad, hess, xgboost::linalg::kC, n_samples, n_targets); - int ret = XGBoosterTrainOneIter(handle, dtrain, static_cast(jiter), s_grad.c_str(), - s_hess.c_str()); - - // release - jenv->ReleaseFloatArrayElements(jgrad, grad, 0); - jenv->ReleaseFloatArrayElements(jhess, hess, 0); - - return ret; + ctx, grad.get(), hess.get(), xgboost::linalg::kC, n_samples, n_targets); + return XGBoosterTrainOneIter(handle, dtrain, static_cast(jiter), s_grad.c_str(), + s_hess.c_str()); API_END(); } @@ -629,30 +700,33 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterTrainOneI */ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterEvalOneIter (JNIEnv *jenv, jclass jcls, jlong jhandle, jint jiter, jlongArray jdmats, jobjectArray jevnames, jobjectArray jout) { - BoosterHandle handle = (BoosterHandle) jhandle; + auto handle = reinterpret_cast(jhandle); std::vector dmats; std::vector evnames; std::vector evchars; size_t len = static_cast(jenv->GetArrayLength(jdmats)); // put handle from jhandles to chandles - jlong* cjdmats = jenv->GetLongArrayElements(jdmats, 0); + std::unique_ptr> cjdmats{ + jenv->GetLongArrayElements(jdmats, nullptr), [&](jlong *ptr) { + jenv->ReleaseLongArrayElements(jdmats, ptr, 0); + }}; for (size_t i = 0; i < len; ++i) { - dmats.push_back((DMatrixHandle) cjdmats[i]); + dmats.push_back(reinterpret_cast(cjdmats.get()[i])); jstring jevname = (jstring)jenv->GetObjectArrayElement(jevnames, i); - const char *s =jenv->GetStringUTFChars(jevname, 0); - evnames.push_back(std::string(s, jenv->GetStringLength(jevname))); - if (s != nullptr) jenv->ReleaseStringUTFChars(jevname, s); + std::unique_ptr> s{jenv->GetStringUTFChars(jevname, nullptr), + [&](char const *ptr) { + jenv->ReleaseStringUTFChars(jevname, ptr); + }}; + evnames.emplace_back(s.get(), jenv->GetStringLength(jevname)); } - jenv->ReleaseLongArrayElements(jdmats, cjdmats, 0); + for (size_t i = 0; i < len; ++i) { evchars.push_back(evnames[i].c_str()); } - const char* result; - int ret = XGBoosterEvalOneIter(handle, jiter, - dmlc::BeginPtr(dmats), - dmlc::BeginPtr(evchars), - len, &result); + const char *result; + int ret = XGBoosterEvalOneIter(handle, jiter, dmlc::BeginPtr(dmats), dmlc::BeginPtr(evchars), len, + &result); JVM_CHECK_CALL(ret); jstring jinfo = nullptr; if (result != nullptr) { @@ -669,8 +743,8 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterEvalOneIt */ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterPredict (JNIEnv *jenv, jclass jcls, jlong jhandle, jlong jdmat, jint joption_mask, jint jntree_limit, jobjectArray jout) { - BoosterHandle handle = (BoosterHandle) jhandle; - DMatrixHandle dmat = (DMatrixHandle) jdmat; + auto handle = reinterpret_cast(jhandle); + auto dmat = reinterpret_cast(jdmat); bst_ulong len; float *result; int ret = XGBoosterPredict(handle, dmat, joption_mask, (unsigned int) jntree_limit, @@ -696,7 +770,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterPredictFr jfloat missing, jint iteration_begin, jint iteration_end, jint predict_type, jfloatArray jmargin, jobjectArray jout) { API_BEGIN(); - BoosterHandle handle = reinterpret_cast(jhandle); + auto handle = reinterpret_cast(jhandle); /** * Create array interface. @@ -770,17 +844,16 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterPredictFr * Method: XGBoosterLoadModel * Signature: (JLjava/lang/String;)V */ -JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterLoadModel - (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfname) { - BoosterHandle handle = (BoosterHandle) jhandle; - const char* fname = jenv->GetStringUTFChars(jfname, 0); - - int ret = XGBoosterLoadModel(handle, fname); - JVM_CHECK_CALL(ret); - if (fname) { - jenv->ReleaseStringUTFChars(jfname,fname); - } - return ret; +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterLoadModel(JNIEnv *jenv, + jclass jcls, + jlong jhandle, + jstring jfname) { + auto handle = reinterpret_cast(jhandle); + std::unique_ptr> fname{jenv->GetStringUTFChars(jfname, nullptr), + [&](char const *ptr) { + jenv->ReleaseStringUTFChars(jfname, ptr); + }}; + return XGBoosterLoadModel(handle, fname.get()); } /* @@ -788,17 +861,18 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterLoadModel * Method: XGBoosterSaveModel * Signature: (JLjava/lang/String;)V */ -JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSaveModel - (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfname) { - BoosterHandle handle = (BoosterHandle) jhandle; - const char* fname = jenv->GetStringUTFChars(jfname, 0); - - int ret = XGBoosterSaveModel(handle, fname); - JVM_CHECK_CALL(ret); - if (fname) { - jenv->ReleaseStringUTFChars(jfname, fname); - } - return ret; +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSaveModel(JNIEnv *jenv, + jclass jcls, + jlong jhandle, + jstring jfname) { + auto handle = reinterpret_cast(jhandle); + std::unique_ptr> fname{ + jenv->GetStringUTFChars(jfname, nullptr), [&](char const *ptr) { + if (ptr) { + jenv->ReleaseStringUTFChars(jfname, ptr); + } + }}; + return XGBoosterSaveModel(handle, fname.get()); } /* @@ -806,15 +880,14 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSaveModel * Method: XGBoosterLoadModelFromBuffer * Signature: (J[B)I */ -JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterLoadModelFromBuffer - (JNIEnv *jenv, jclass jcls, jlong jhandle, jbyteArray jbytes) { - BoosterHandle handle = (BoosterHandle) jhandle; - jbyte* buffer = jenv->GetByteArrayElements(jbytes, 0); - int ret = XGBoosterLoadModelFromBuffer( - handle, buffer, jenv->GetArrayLength(jbytes)); - JVM_CHECK_CALL(ret); - jenv->ReleaseByteArrayElements(jbytes, buffer, 0); - return ret; +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterLoadModelFromBuffer( + JNIEnv *jenv, jclass jcls, jlong jhandle, jbyteArray jbytes) { + auto handle = reinterpret_cast(jhandle); + std::unique_ptr> buffer{jenv->GetByteArrayElements(jbytes, nullptr), + [&](jbyte *ptr) { + jenv->ReleaseByteArrayElements(jbytes, ptr, 0); + }}; + return XGBoosterLoadModelFromBuffer(handle, buffer.get(), jenv->GetArrayLength(jbytes)); } /* @@ -824,12 +897,17 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterLoadModel */ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSaveModelToBuffer (JNIEnv * jenv, jclass jcls, jlong jhandle, jstring jformat, jobjectArray jout) { - BoosterHandle handle = (BoosterHandle) jhandle; - const char *format = jenv->GetStringUTFChars(jformat, 0); + auto handle = reinterpret_cast(jhandle); + std::unique_ptr> format{ + jenv->GetStringUTFChars(jformat, nullptr), [&](char const *ptr) { + if (ptr) { + jenv->ReleaseStringUTFChars(jformat, ptr); + } + }}; bst_ulong len = 0; const char *result{nullptr}; - xgboost::Json config {xgboost::Object{}}; - config["format"] = std::string{format}; + xgboost::Json config{xgboost::Object{}}; + config["format"] = std::string{format.get()}; std::string config_str; xgboost::Json::Dump(config, &config_str); @@ -850,13 +928,23 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSaveModel */ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterDumpModelEx (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfmap, jint jwith_stats, jstring jformat, jobjectArray jout) { - BoosterHandle handle = (BoosterHandle) jhandle; - const char *fmap = jenv->GetStringUTFChars(jfmap, 0); - const char *format = jenv->GetStringUTFChars(jformat, 0); + auto handle = reinterpret_cast(jhandle); + std::unique_ptr> fmap{jenv->GetStringUTFChars(jfmap, nullptr), + [&](char const *ptr) { + if (ptr) { + jenv->ReleaseStringUTFChars(jfmap, ptr); + } + }}; + std::unique_ptr> format{ + jenv->GetStringUTFChars(jformat, nullptr), [&](char const *ptr) { + if (ptr) { + jenv->ReleaseStringUTFChars(jformat, ptr); + } + }}; bst_ulong len = 0; - char **result; + char const **result; - int ret = XGBoosterDumpModelEx(handle, fmap, jwith_stats, format, &len, (const char ***) &result); + int ret = XGBoosterDumpModelEx(handle, fmap.get(), jwith_stats, format.get(), &len, &result); JVM_CHECK_CALL(ret); jsize jlen = (jsize) len; @@ -866,7 +954,6 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterDumpModel } jenv->SetObjectArrayElement(jout, 0, jinfos); - if (fmap) jenv->ReleaseStringUTFChars(jfmap, (const char *)fmap); return ret; } @@ -878,37 +965,48 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterDumpModel JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterDumpModelExWithFeatures (JNIEnv *jenv, jclass jcls, jlong jhandle, jobjectArray jfeature_names, jint jwith_stats, jstring jformat, jobjectArray jout) { - - BoosterHandle handle = (BoosterHandle) jhandle; + auto handle = reinterpret_cast(jhandle); bst_ulong feature_num = (bst_ulong)jenv->GetArrayLength(jfeature_names); std::vector feature_names; - std::vector feature_names_char; + std::vector feature_names_char; std::string feature_type_q = "q"; - std::vector feature_types_char; + std::vector feature_types_char; for (bst_ulong i = 0; i < feature_num; ++i) { jstring jfeature_name = (jstring)jenv->GetObjectArrayElement(jfeature_names, i); - const char *s = jenv->GetStringUTFChars(jfeature_name, 0); - feature_names.push_back(std::string(s, jenv->GetStringLength(jfeature_name))); - if (s != nullptr) jenv->ReleaseStringUTFChars(jfeature_name, s); - if (feature_names.back().length() == 0) feature_names.pop_back(); + std::unique_ptr> s{ + jenv->GetStringUTFChars(jfeature_name, nullptr), [&](char const *ptr) { + if (ptr != nullptr) { + jenv->ReleaseStringUTFChars(jfeature_name, ptr); + } + }}; + feature_names.emplace_back(s.get(), jenv->GetStringLength(jfeature_name)); + + if (feature_names.back().length() == 0) { + feature_names.pop_back(); + } } for (size_t i = 0; i < feature_names.size(); ++i) { - feature_names_char.push_back(&feature_names[i][0]); - feature_types_char.push_back(&feature_type_q[0]); + feature_names_char.push_back(feature_names[i].c_str()); + feature_types_char.push_back(feature_type_q.c_str()); } - const char *format = jenv->GetStringUTFChars(jformat, 0); + std::unique_ptr> format{ + jenv->GetStringUTFChars(jformat, nullptr), [&](char const *ptr) { + if (ptr) { + jenv->ReleaseStringUTFChars(jformat, ptr); + } + }}; bst_ulong len = 0; char **result; - int ret = XGBoosterDumpModelExWithFeatures(handle, feature_num, - (const char **) dmlc::BeginPtr(feature_names_char), - (const char **) dmlc::BeginPtr(feature_types_char), - jwith_stats, format, &len, (const char ***) &result); + int ret = XGBoosterDumpModelExWithFeatures( + handle, feature_num, (const char **)dmlc::BeginPtr(feature_names_char), + (const char **)dmlc::BeginPtr(feature_types_char), jwith_stats, format.get(), &len, + (const char ***)&result); JVM_CHECK_CALL(ret); jsize jlen = (jsize) len; @@ -949,16 +1047,20 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetAttrNa * Method: XGBoosterGetAttr * Signature: (JLjava/lang/String;[Ljava/lang/String;)I */ -JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetAttr - (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jkey, jobjectArray jout) { - BoosterHandle handle = (BoosterHandle) jhandle; - const char* key = jenv->GetStringUTFChars(jkey, 0); - const char* result; +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetAttr( + JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jkey, jobjectArray jout) { + auto handle = reinterpret_cast(jhandle); + std::unique_ptr> key{jenv->GetStringUTFChars(jkey, nullptr), + [&](char const *ptr) { + if (ptr) { + jenv->ReleaseStringUTFChars(jkey, ptr); + } + }}; + + const char *result; int success; - int ret = XGBoosterGetAttr(handle, key, &result, &success); + int ret = XGBoosterGetAttr(handle, key.get(), &result, &success); JVM_CHECK_CALL(ret); - //release - if (key) jenv->ReleaseStringUTFChars(jkey, key); if (success > 0) { jstring jret = jenv->NewStringUTF(result); @@ -973,17 +1075,22 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetAttr * Method: XGBoosterSetAttr * Signature: (JLjava/lang/String;Ljava/lang/String;)I */ -JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSetAttr - (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jkey, jstring jvalue) { - BoosterHandle handle = (BoosterHandle) jhandle; - const char* key = jenv->GetStringUTFChars(jkey, 0); - const char* value = jenv->GetStringUTFChars(jvalue, 0); - int ret = XGBoosterSetAttr(handle, key, value); - JVM_CHECK_CALL(ret); - //release - if (key) jenv->ReleaseStringUTFChars(jkey, key); - if (value) jenv->ReleaseStringUTFChars(jvalue, value); - return ret; +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSetAttr( + JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jkey, jstring jvalue) { + auto handle = reinterpret_cast(jhandle); + std::unique_ptr> key{jenv->GetStringUTFChars(jkey, nullptr), + [&](char const *ptr) { + if (ptr) { + jenv->ReleaseStringUTFChars(jkey, ptr); + } + }}; + std::unique_ptr> value{ + jenv->GetStringUTFChars(jvalue, nullptr), [&](char const *ptr) { + if (ptr) { + jenv->ReleaseStringUTFChars(jvalue, ptr); + } + }}; + return XGBoosterSetAttr(handle, key.get(), value.get()); } /* @@ -991,9 +1098,9 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSetAttr * Method: XGBoosterGetNumFeature * Signature: (J[J)I */ -JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumFeature - (JNIEnv *jenv, jclass jcls, jlong jhandle, jlongArray jout) { - BoosterHandle handle = (BoosterHandle) jhandle; +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumFeature( + JNIEnv *jenv, jclass jcls, jlong jhandle, jlongArray jout) { + auto handle = reinterpret_cast(jhandle); bst_ulong num_feature; int ret = XGBoosterGetNumFeature(handle, &num_feature); JVM_CHECK_CALL(ret); @@ -1004,7 +1111,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumFea JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetNumBoostedRound( JNIEnv *jenv, jclass, jlong jhandle, jintArray jout) { - BoosterHandle handle = (BoosterHandle)jhandle; + auto handle = reinterpret_cast(jhandle); std::int32_t n_rounds{0}; auto ret = XGBoosterBoostedRounds(handle, &n_rounds); JVM_CHECK_CALL(ret); @@ -1022,9 +1129,13 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorInit(J jclass jcls, jstring jargs) { xgboost::Json config{xgboost::Object{}}; - const char *args = jenv->GetStringUTFChars(jargs, nullptr); - JVM_CHECK_CALL(XGCommunicatorInit(args)); - return 0; + std::unique_ptr> args{jenv->GetStringUTFChars(jargs, nullptr), + [&](char const *ptr) { + if (ptr) { + jenv->ReleaseStringUTFChars(jargs, ptr); + } + }}; + return XGCommunicatorInit(args.get()); } /* @@ -1039,7 +1150,11 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerCreate( TrackerHandle handle; Json config{Object{}}; - std::string shost{jenv->GetStringUTFChars(host, nullptr), + std::unique_ptr> p_shost{jenv->GetStringUTFChars(host, nullptr), + [&](char const *ptr) { + jenv->ReleaseStringUTFChars(host, ptr); + }}; + std::string shost{p_shost.get(), static_cast(jenv->GetStringLength(host))}; if (!shost.empty()) { config["host"] = shost; @@ -1136,12 +1251,17 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorFinali * Method: CommunicatorPrint * Signature: (Ljava/lang/String;)I */ -JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorPrint - (JNIEnv *jenv, jclass jcls, jstring jmsg) { - std::string str(jenv->GetStringUTFChars(jmsg, 0), - jenv->GetStringLength(jmsg)); - JVM_CHECK_CALL(XGCommunicatorPrint(str.c_str())); - return 0; +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_CommunicatorPrint(JNIEnv *jenv, + jclass jcls, + jstring jmsg) { + std::unique_ptr> msg{jenv->GetStringUTFChars(jmsg, nullptr), + [&](char const *ptr) { + if (ptr) { + jenv->ReleaseStringUTFChars(jmsg, ptr); + } + }}; + std::string str(msg.get(), jenv->GetStringLength(jmsg)); + return XGCommunicatorPrint(str.c_str()); } /* @@ -1210,11 +1330,15 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDeviceQuantileDM * Method: XGQuantileDMatrixCreateFromCallback * Signature: (Ljava/util/Iterator;Ljava/util/Iterator;Ljava/lang/String;[J)I */ -JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGQuantileDMatrixCreateFromCallback - (JNIEnv *jenv, jclass jcls, jobject jdata_iter, jobject jref_iter, jstring jconf, jlongArray jout) { - char const *conf = jenv->GetStringUTFChars(jconf, 0); +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGQuantileDMatrixCreateFromCallback( + JNIEnv *jenv, jclass jcls, jobject jdata_iter, jobject jref_iter, jstring jconf, + jlongArray jout) { + std::unique_ptr> conf{jenv->GetStringUTFChars(jconf, nullptr), + [&](char const *ptr) { + jenv->ReleaseStringUTFChars(jconf, ptr); + }}; return xgboost::jni::XGQuantileDMatrixCreateFromCallbackImpl(jenv, jcls, jdata_iter, jref_iter, - conf, jout); + conf.get(), jout); } /* @@ -1222,18 +1346,19 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGQuantileDMatrixC * Method: XGDMatrixSetInfoFromInterface * Signature: (JLjava/lang/String;Ljava/lang/String;)I */ -JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetInfoFromInterface - (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jstring jjson_columns) { - DMatrixHandle handle = (DMatrixHandle) jhandle; - const char* field = jenv->GetStringUTFChars(jfield, 0); - const char* cjson_columns = jenv->GetStringUTFChars(jjson_columns, 0); +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetInfoFromInterface( + JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jstring jjson_columns) { + auto handle = reinterpret_cast(jhandle); + std::unique_ptr> field{jenv->GetStringUTFChars(jfield, nullptr), + [&](char const *ptr) { + jenv->ReleaseStringUTFChars(jfield, ptr); + }}; + std::unique_ptr> cjson_columns{ + jenv->GetStringUTFChars(jjson_columns, nullptr), [&](char const *ptr) { + jenv->ReleaseStringUTFChars(jjson_columns, ptr); + }}; - int ret = XGDMatrixSetInfoFromInterface(handle, field, cjson_columns); - JVM_CHECK_CALL(ret); - //release - if (field) jenv->ReleaseStringUTFChars(jfield, field); - if (cjson_columns) jenv->ReleaseStringUTFChars(jjson_columns, cjson_columns); - return ret; + return XGDMatrixSetInfoFromInterface(handle, field.get(), cjson_columns.get()); } /* @@ -1244,7 +1369,10 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetInfoFr JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromArrayInterfaceColumns (JNIEnv *jenv, jclass jcls, jstring jjson_columns, jfloat jmissing, jint jnthread, jlongArray jout) { DMatrixHandle result; - const char* cjson_columns = jenv->GetStringUTFChars(jjson_columns, nullptr); + std::unique_ptr> cjson_columns{ + jenv->GetStringUTFChars(jjson_columns, nullptr), [&](char const *ptr) { + jenv->ReleaseStringUTFChars(jjson_columns, ptr); + }}; xgboost::Json config{xgboost::Object{}}; auto missing = static_cast(jmissing); auto n_threads = static_cast(jnthread); @@ -1252,43 +1380,38 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFro config["nthread"] = xgboost::Integer(n_threads); std::string config_str; xgboost::Json::Dump(config, &config_str); - int ret = XGDMatrixCreateFromCudaColumnar(cjson_columns, config_str.c_str(), - &result); + int ret = XGDMatrixCreateFromCudaColumnar(cjson_columns.get(), config_str.c_str(), &result); JVM_CHECK_CALL(ret); - if (cjson_columns) { - jenv->ReleaseStringUTFChars(jjson_columns, cjson_columns); - } - setHandle(jenv, jout, result); return ret; } JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetStrFeatureInfo (JNIEnv *jenv, jclass jclz, jlong jhandle, jstring jfield, jobjectArray jvalues) { - DMatrixHandle handle = (DMatrixHandle) jhandle; - const char* field = jenv->GetStringUTFChars(jfield, 0); + auto handle = reinterpret_cast(jhandle); + std::unique_ptr> field{jenv->GetStringUTFChars(jfield, nullptr), + [&](char const *ptr) { + jenv->ReleaseStringUTFChars(jfield, ptr); + }}; int size = jenv->GetArrayLength(jvalues); // tmp storage for java strings std::vector values; for (int i = 0; i < size; i++) { jstring jstr = (jstring)(jenv->GetObjectArrayElement(jvalues, i)); - const char *value = jenv->GetStringUTFChars(jstr, 0); - values.emplace_back(value); - if (value) jenv->ReleaseStringUTFChars(jstr, value); + std::unique_ptr> value{jenv->GetStringUTFChars(jstr, nullptr), + [&](char const *ptr) { + jenv->ReleaseStringUTFChars(jstr, ptr); + }}; + values.emplace_back(value.get()); } - std::vector c_values; + std::vector c_values; c_values.resize(size); - std::transform(values.cbegin(), values.cend(), - c_values.begin(), + std::transform(values.cbegin(), values.cend(), c_values.begin(), [](auto const &str) { return str.c_str(); }); - int ret = XGDMatrixSetStrFeatureInfo(handle, field, c_values.data(), size); - JVM_CHECK_CALL(ret); - - if (field) jenv->ReleaseStringUTFChars(jfield, field); - return ret; + return XGDMatrixSetStrFeatureInfo(handle, field.get(), c_values.data(), size); } /* @@ -1296,28 +1419,29 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetStrFea * Method: XGDMatrixGetStrFeatureInfo * Signature: (JLjava/lang/String;[J[[Ljava/lang/String;)I */ -JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixGetStrFeatureInfo - (JNIEnv *jenv, jclass jclz, jlong jhandle, jstring jfield, jlongArray joutLenArray, - jobjectArray joutValueArray) { - DMatrixHandle handle = (DMatrixHandle) jhandle; - const char *field = jenv->GetStringUTFChars(jfield, 0); +JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixGetStrFeatureInfo( + JNIEnv *jenv, jclass jclz, jlong jhandle, jstring jfield, jlongArray joutLenArray, + jobjectArray joutValueArray) { + auto handle = reinterpret_cast(jhandle); + std::unique_ptr> field{jenv->GetStringUTFChars(jfield, nullptr), + [&](char const *ptr) { + jenv->ReleaseStringUTFChars(jfield, ptr); + }}; bst_ulong out_len = 0; char const **c_out_features; - int ret = XGDMatrixGetStrFeatureInfo(handle, field, &out_len, &c_out_features); + int ret = XGDMatrixGetStrFeatureInfo(handle, field.get(), &out_len, &c_out_features); - jlong jlen = (jlong) out_len; + jlong jlen = (jlong)out_len; jenv->SetLongArrayRegion(joutLenArray, 0, 1, &jlen); - jobjectArray jinfos = jenv->NewObjectArray(jlen, jenv->FindClass("java/lang/String"), - jenv->NewStringUTF("")); + jobjectArray jinfos = + jenv->NewObjectArray(jlen, jenv->FindClass("java/lang/String"), jenv->NewStringUTF("")); for (int i = 0; i < jlen; i++) { jenv->SetObjectArrayElement(jinfos, i, jenv->NewStringUTF(c_out_features[i])); } jenv->SetObjectArrayElement(joutValueArray, 0, jinfos); - JVM_CHECK_CALL(ret); - if (field) jenv->ReleaseStringUTFChars(jfield, field); return ret; } @@ -1330,10 +1454,12 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSetStrFeatureInfo( JNIEnv *jenv, jclass jclz, jlong jhandle, jstring jfield, jobjectArray jfeatures) { - BoosterHandle handle = (BoosterHandle)jhandle; - - const char *field = jenv->GetStringUTFChars(jfield, 0); + auto handle = reinterpret_cast(jhandle); + std::unique_ptr> field{jenv->GetStringUTFChars(jfield, nullptr), + [&](char const *ptr) { + jenv->ReleaseStringUTFChars(jfield, ptr); + }}; bst_ulong feature_num = (bst_ulong)jenv->GetArrayLength(jfeatures); std::vector features; @@ -1341,19 +1467,21 @@ Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterSetStrFeatureInfo( for (bst_ulong i = 0; i < feature_num; ++i) { jstring jfeature = (jstring)jenv->GetObjectArrayElement(jfeatures, i); - const char *s = jenv->GetStringUTFChars(jfeature, 0); - features.push_back(std::string(s, jenv->GetStringLength(jfeature))); - if (s != nullptr) jenv->ReleaseStringUTFChars(jfeature, s); + std::unique_ptr> s{ + jenv->GetStringUTFChars(jfeature, nullptr), [&](char const *ptr) { + if (ptr) { + jenv->ReleaseStringUTFChars(jfeature, ptr); + } + }}; + features.emplace_back(s.get(), jenv->GetStringLength(jfeature)); } for (size_t i = 0; i < features.size(); ++i) { features_char.push_back(features[i].c_str()); } - int ret = XGBoosterSetStrFeatureInfo( - handle, field, dmlc::BeginPtr(features_char), feature_num); - JVM_CHECK_CALL(ret); - return ret; + return XGBoosterSetStrFeatureInfo(handle, field.get(), dmlc::BeginPtr(features_char), + feature_num); } /* @@ -1365,17 +1493,19 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterGetStrFeatureInfo( JNIEnv *jenv, jclass jclz, jlong jhandle, jstring jfield, jobjectArray jout) { - BoosterHandle handle = (BoosterHandle)jhandle; - - const char *field = jenv->GetStringUTFChars(jfield, 0); + auto handle = reinterpret_cast(jhandle); + std::unique_ptr> field{jenv->GetStringUTFChars(jfield, nullptr), + [&](char const *ptr) { + jenv->ReleaseStringUTFChars(jfield, ptr); + }}; bst_ulong feature_num = (bst_ulong)jenv->GetArrayLength(jout); const char **features; std::vector features_char; - int ret = XGBoosterGetStrFeatureInfo(handle, field, &feature_num, - (const char ***)&features); + int ret = + XGBoosterGetStrFeatureInfo(handle, field.get(), &feature_num, (const char ***)&features); JVM_CHECK_CALL(ret); for (bst_ulong i = 0; i < feature_num; i++) { diff --git a/tests/cpp/collective/test_worker.h b/tests/cpp/collective/test_worker.h index 19e5e590a..78f4a28d8 100644 --- a/tests/cpp/collective/test_worker.h +++ b/tests/cpp/collective/test_worker.h @@ -13,8 +13,10 @@ #include // for vector #include "../../../src/collective/comm.h" -#include "../../../src/collective/tracker.h" // for GetHostAddress -#include "../helpers.h" // for FileExists +#include "../../../src/collective/communicator-inl.h" // for Init, Finalize +#include "../../../src/collective/tracker.h" // for GetHostAddress +#include "../../../src/common/common.h" // for AllVisibleGPUs +#include "../helpers.h" // for FileExists #if defined(XGBOOST_USE_FEDERATED) #include "../plugin/federated/test_worker.h" diff --git a/tests/cpp/helpers.cc b/tests/cpp/helpers.cc index 6a89207e0..a47610636 100644 --- a/tests/cpp/helpers.cc +++ b/tests/cpp/helpers.cc @@ -12,9 +12,9 @@ #include #include -#include #include +#include "../../src/collective/communicator-inl.h" // for GetRank #include "../../src/data/adapter.h" #include "../../src/data/iterative_dmatrix.h" #include "../../src/data/simple_dmatrix.h" diff --git a/tests/cpp/helpers.h b/tests/cpp/helpers.h index 97f3db077..cb8852e1b 100644 --- a/tests/cpp/helpers.h +++ b/tests/cpp/helpers.h @@ -19,9 +19,11 @@ #include #include -#include "../../src/collective/communicator-inl.h" -#include "../../src/common/common.h" -#include "../../src/common/threading_utils.h" +#if defined(__CUDACC__) +#include "../../src/collective/communicator-inl.h" // for GetRank +#include "../../src/common/common.h" // for AllVisibleGPUs +#endif // defined(__CUDACC__) + #include "filesystem.h" // dmlc::TemporaryDirectory #include "xgboost/linalg.h" #if !defined(_OPENMP)