[jvm-packages] XGBoost4j Windows fixes (#1639)
* Changes for Mingw64 compilation to ensure long is a consistent size. Mainly impacts the Java API which would not compile, but there may be silent errors on Windows with large datasets before this patch (as long is 32-bits when compiled with mingw64 even in 64-bit mode). * Adding ifdefs to ensure it still compiles on MacOS * Makefile and create_jni.bat changes for Windows. * Switching XGDMatrixCreateFromCSREx JNI call to use size_t cast * Fixing lint error, adding profile switching to jvm-packages build to make create-jni.bat get called, adding myself to Contributors.Md
This commit is contained in:
@@ -12,6 +12,7 @@
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
#include <cstdint>
|
||||
#include <xgboost/c_api.h>
|
||||
#include <xgboost/base.h>
|
||||
#include <xgboost/logging.h>
|
||||
@@ -23,7 +24,11 @@
|
||||
// helper functions
|
||||
// set handle
|
||||
void setHandle(JNIEnv *jenv, jlongArray jhandle, void* handle) {
|
||||
#ifdef __APPLE__
|
||||
long out = (long) handle;
|
||||
#else
|
||||
int64_t out = (int64_t) handle;
|
||||
#endif
|
||||
jenv->SetLongArrayRegion(jhandle, 0, 1, &out);
|
||||
}
|
||||
|
||||
@@ -87,7 +92,7 @@ XGB_EXTERN_C int XGBoost4jCallbackDataIterNext(
|
||||
cbatch.weight = nullptr;
|
||||
}
|
||||
long max_elem = cbatch.offset[cbatch.size];
|
||||
cbatch.index = jenv->GetIntArrayElements(jindex, 0);
|
||||
cbatch.index = (int*) jenv->GetIntArrayElements(jindex, 0);
|
||||
cbatch.value = jenv->GetFloatArrayElements(jvalue, 0);
|
||||
CHECK_EQ(jenv->GetArrayLength(jindex), max_elem)
|
||||
<< "batch.index.length must equal batch.offset.back()";
|
||||
@@ -107,7 +112,7 @@ XGB_EXTERN_C int XGBoost4jCallbackDataIterNext(
|
||||
jenv->ReleaseFloatArrayElements(jweight, cbatch.weight, 0);
|
||||
jenv->DeleteLocalRef(jweight);
|
||||
}
|
||||
jenv->ReleaseIntArrayElements(jindex, cbatch.index, 0);
|
||||
jenv->ReleaseIntArrayElements(jindex, (jint*) cbatch.index, 0);
|
||||
jenv->DeleteLocalRef(jindex);
|
||||
jenv->ReleaseFloatArrayElements(jvalue, cbatch.value, 0);
|
||||
jenv->DeleteLocalRef(jvalue);
|
||||
@@ -199,7 +204,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFro
|
||||
jfloat* data = jenv->GetFloatArrayElements(jdata, 0);
|
||||
bst_ulong nindptr = (bst_ulong)jenv->GetArrayLength(jindptr);
|
||||
bst_ulong nelem = (bst_ulong)jenv->GetArrayLength(jdata);
|
||||
int ret = (jint) XGDMatrixCreateFromCSREx((unsigned long const *)indptr, (unsigned int const *)indices, (float const *)data, nindptr, nelem, jcol, &result);
|
||||
jint ret = (jint) XGDMatrixCreateFromCSREx((size_t const *)indptr, (unsigned int const *)indices, (float const *)data, nindptr, nelem, jcol, &result);
|
||||
setHandle(jenv, jout, result);
|
||||
//Release
|
||||
jenv->ReleaseLongArrayElements(jindptr, indptr, 0);
|
||||
@@ -222,7 +227,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFro
|
||||
bst_ulong nindptr = (bst_ulong)jenv->GetArrayLength(jindptr);
|
||||
bst_ulong nelem = (bst_ulong)jenv->GetArrayLength(jdata);
|
||||
|
||||
int ret = (jint) XGDMatrixCreateFromCSCEx((unsigned long const *)indptr, (unsigned int const *)indices, (float const *)data, nindptr, nelem, jrow, &result);
|
||||
jint ret = (jint) XGDMatrixCreateFromCSCEx((size_t const *)indptr, (unsigned int const *)indices, (float const *)data, nindptr, nelem, jrow, &result);
|
||||
setHandle(jenv, jout, result);
|
||||
//release
|
||||
jenv->ReleaseLongArrayElements(jindptr, indptr, 0);
|
||||
@@ -244,7 +249,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFro
|
||||
jfloat* data = jenv->GetFloatArrayElements(jdata, 0);
|
||||
bst_ulong nrow = (bst_ulong)jnrow;
|
||||
bst_ulong ncol = (bst_ulong)jncol;
|
||||
int ret = (jint) XGDMatrixCreateFromMat((float const *)data, nrow, ncol, jmiss, &result);
|
||||
jint ret = (jint) XGDMatrixCreateFromMat((float const *)data, nrow, ncol, jmiss, &result);
|
||||
setHandle(jenv, jout, result);
|
||||
//release
|
||||
jenv->ReleaseFloatArrayElements(jdata, data, 0);
|
||||
@@ -264,7 +269,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSliceDMat
|
||||
jint* indexset = jenv->GetIntArrayElements(jindexset, 0);
|
||||
bst_ulong len = (bst_ulong)jenv->GetArrayLength(jindexset);
|
||||
|
||||
int ret = XGDMatrixSliceDMatrix(handle, (int const *)indexset, len, &result);
|
||||
jint ret = (jint) XGDMatrixSliceDMatrix(handle, (int const *)indexset, len, &result);
|
||||
setHandle(jenv, jout, result);
|
||||
//release
|
||||
jenv->ReleaseIntArrayElements(jindexset, indexset, 0);
|
||||
@@ -650,7 +655,8 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterLoadRabit
|
||||
BoosterHandle handle = (BoosterHandle) jhandle;
|
||||
int version;
|
||||
int ret = XGBoosterLoadRabitCheckpoint(handle, &version);
|
||||
jenv->SetIntArrayRegion(jout, 0, 1, &version);
|
||||
jint jversion = version;
|
||||
jenv->SetIntArrayRegion(jout, 0, 1, &jversion);
|
||||
return ret;
|
||||
}
|
||||
|
||||
@@ -722,7 +728,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitTrackerPrint
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitGetRank
|
||||
(JNIEnv *jenv, jclass jcls, jintArray jout) {
|
||||
int rank = RabitGetRank();
|
||||
jint rank = RabitGetRank();
|
||||
jenv->SetIntArrayRegion(jout, 0, 1, &rank);
|
||||
return 0;
|
||||
}
|
||||
@@ -734,7 +740,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitGetRank
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitGetWorldSize
|
||||
(JNIEnv *jenv, jclass jcls, jintArray jout) {
|
||||
int out = RabitGetWorldSize();
|
||||
jint out = RabitGetWorldSize();
|
||||
jenv->SetIntArrayRegion(jout, 0, 1, &out);
|
||||
return 0;
|
||||
}
|
||||
@@ -746,7 +752,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitGetWorldSize
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitVersionNumber
|
||||
(JNIEnv *jenv, jclass jcls, jintArray jout) {
|
||||
int out = RabitVersionNumber();
|
||||
jint out = RabitVersionNumber();
|
||||
jenv->SetIntArrayRegion(jout, 0, 1, &out);
|
||||
return 0;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user