[jvm-packages] XGBoost4j Windows fixes (#1639)

* Changes for Mingw64 compilation to ensure long is a consistent size.

Mainly impacts the Java API which would not compile, but there may be
silent errors on Windows with large datasets before this patch (as long
is 32-bits when compiled with mingw64 even in 64-bit mode).

* Adding ifdefs to ensure it still compiles on MacOS

* Makefile and create_jni.bat changes for Windows.

* Switching XGDMatrixCreateFromCSREx JNI call to use size_t cast

* Fixing lint error, adding profile switching to jvm-packages build to make create-jni.bat get called, adding myself to Contributors.Md
This commit is contained in:
Adam Pocock 2016-10-18 08:35:25 -04:00 committed by Nan Zhu
parent be90deb9b6
commit 445029bb82
7 changed files with 120 additions and 53 deletions

View File

@ -61,3 +61,4 @@ List of Contributors
* [Damien Carol](https://github.com/damiencarol) * [Damien Carol](https://github.com/damiencarol)
* [Alex Bain](https://github.com/convexquad) * [Alex Bain](https://github.com/convexquad)
* [Baltazar Bieniek](https://github.com/bbieniek) * [Baltazar Bieniek](https://github.com/bbieniek)
* [Adam Pocock](https://github.com/Craigacp)

View File

@ -62,6 +62,7 @@ ifneq ($(UNAME), Windows)
XGBOOST_DYLIB = lib/libxgboost.so XGBOOST_DYLIB = lib/libxgboost.so
else else
XGBOOST_DYLIB = lib/libxgboost.dll XGBOOST_DYLIB = lib/libxgboost.dll
JAVAINCFLAGS += -I${JAVA_HOME}/include/win32
endif endif
ifeq ($(UNAME), Linux) ifeq ($(UNAME), Linux)

View File

@ -47,7 +47,7 @@ namespace xgboost {
*/ */
typedef uint32_t bst_uint; typedef uint32_t bst_uint;
/*! \brief long integers */ /*! \brief long integers */
typedef unsigned long bst_ulong; // NOLINT(*) typedef uint64_t bst_ulong; // NOLINT(*)
/*! \brief float type, used for storing statistics */ /*! \brief float type, used for storing statistics */
typedef float bst_float; typedef float bst_float;

View File

@ -24,7 +24,7 @@ XGB_EXTERN_C {
#endif #endif
// manually define unsign long // manually define unsign long
typedef unsigned long bst_ulong; // NOLINT(*) typedef uint64_t bst_ulong; // NOLINT(*)
/*! \brief handle to DMatrix */ /*! \brief handle to DMatrix */
typedef void *DMatrixHandle; typedef void *DMatrixHandle;
@ -40,7 +40,13 @@ typedef struct {
/*! \brief number of rows in the minibatch */ /*! \brief number of rows in the minibatch */
size_t size; size_t size;
/*! \brief row pointer to the rows in the data */ /*! \brief row pointer to the rows in the data */
#ifdef __APPLE__
/* Necessary as Java on MacOS defines jlong as long int
* and gcc defines int64_t as long long int. */
long* offset; // NOLINT(*) long* offset; // NOLINT(*)
#else
int64_t* offset; // NOLINT(*)
#endif
/*! \brief labels of each instance */ /*! \brief labels of each instance */
float* label; float* label;
/*! \brief weight of each instance, can be NULL */ /*! \brief weight of each instance, can be NULL */

View File

@ -1,20 +1,19 @@
echo "move native library" echo "copy native library"
set libsource=..\windows\x64\Release\xgboost4j.dll set libsource=..\lib\libxgboost4j.so
if not exist %libsource% ( if not exist %libsource% (
goto end goto end
) )
set libfolder=xgboost4j\src\main\resources\lib set libfolder=src\main\resources\lib
set libpath=%libfolder%\xgboost4j.dll set libpath=%libfolder%\xgboost4j.dll
if not exist %libfolder% (mkdir %libfolder%) if not exist %libfolder% (mkdir %libfolder%)
if exist %libpath% (del %libpath%) if exist %libpath% (del %libpath%)
move %libsource% %libfolder% copy %libsource% %libpath%
echo complete echo complete
pause
exit exit
:end :end
echo "source library not found, please build it first from ..\windows\xgboost.sln" echo "source library not found, please build it first by runing mingw32-make jvm"
pause pause
exit exit

View File

@ -11,6 +11,14 @@
<artifactId>xgboost4j</artifactId> <artifactId>xgboost4j</artifactId>
<version>0.7</version> <version>0.7</version>
<packaging>jar</packaging> <packaging>jar</packaging>
<profiles>
<profile>
<id>NotWindows</id>
<activation>
<os>
<family>!windows</family>
</os>
</activation>
<build> <build>
<plugins> <plugins>
<plugin> <plugin>
@ -47,6 +55,52 @@
</plugin> </plugin>
</plugins> </plugins>
</build> </build>
</profile>
<profile>
<id>Windows</id>
<activation>
<os>
<family>windows</family>
</os>
</activation>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-javadoc-plugin</artifactId>
<version>2.10.3</version>
<configuration>
<show>protected</show>
<nohelp>true</nohelp>
</configuration>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-assembly-plugin</artifactId>
<configuration>
<skipAssembly>false</skipAssembly>
</configuration>
</plugin>
<plugin>
<artifactId>exec-maven-plugin</artifactId>
<groupId>org.codehaus.mojo</groupId>
<executions>
<execution><!-- Run our version calculation script -->
<id>native</id>
<phase>generate-sources</phase>
<goals>
<goal>exec</goal>
</goals>
<configuration>
<executable>create_jni.bat</executable>
</configuration>
</execution>
</executions>
</plugin>
</plugins>
</build>
</profile>
</profiles>
<dependencies> <dependencies>
<dependency> <dependency>
<groupId>junit</groupId> <groupId>junit</groupId>

View File

@ -12,6 +12,7 @@
limitations under the License. limitations under the License.
*/ */
#include <cstdint>
#include <xgboost/c_api.h> #include <xgboost/c_api.h>
#include <xgboost/base.h> #include <xgboost/base.h>
#include <xgboost/logging.h> #include <xgboost/logging.h>
@ -23,7 +24,11 @@
// helper functions // helper functions
// set handle // set handle
void setHandle(JNIEnv *jenv, jlongArray jhandle, void* handle) { void setHandle(JNIEnv *jenv, jlongArray jhandle, void* handle) {
#ifdef __APPLE__
long out = (long) handle; long out = (long) handle;
#else
int64_t out = (int64_t) handle;
#endif
jenv->SetLongArrayRegion(jhandle, 0, 1, &out); jenv->SetLongArrayRegion(jhandle, 0, 1, &out);
} }
@ -87,7 +92,7 @@ XGB_EXTERN_C int XGBoost4jCallbackDataIterNext(
cbatch.weight = nullptr; cbatch.weight = nullptr;
} }
long max_elem = cbatch.offset[cbatch.size]; long max_elem = cbatch.offset[cbatch.size];
cbatch.index = jenv->GetIntArrayElements(jindex, 0); cbatch.index = (int*) jenv->GetIntArrayElements(jindex, 0);
cbatch.value = jenv->GetFloatArrayElements(jvalue, 0); cbatch.value = jenv->GetFloatArrayElements(jvalue, 0);
CHECK_EQ(jenv->GetArrayLength(jindex), max_elem) CHECK_EQ(jenv->GetArrayLength(jindex), max_elem)
<< "batch.index.length must equal batch.offset.back()"; << "batch.index.length must equal batch.offset.back()";
@ -107,7 +112,7 @@ XGB_EXTERN_C int XGBoost4jCallbackDataIterNext(
jenv->ReleaseFloatArrayElements(jweight, cbatch.weight, 0); jenv->ReleaseFloatArrayElements(jweight, cbatch.weight, 0);
jenv->DeleteLocalRef(jweight); jenv->DeleteLocalRef(jweight);
} }
jenv->ReleaseIntArrayElements(jindex, cbatch.index, 0); jenv->ReleaseIntArrayElements(jindex, (jint*) cbatch.index, 0);
jenv->DeleteLocalRef(jindex); jenv->DeleteLocalRef(jindex);
jenv->ReleaseFloatArrayElements(jvalue, cbatch.value, 0); jenv->ReleaseFloatArrayElements(jvalue, cbatch.value, 0);
jenv->DeleteLocalRef(jvalue); jenv->DeleteLocalRef(jvalue);
@ -199,7 +204,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFro
jfloat* data = jenv->GetFloatArrayElements(jdata, 0); jfloat* data = jenv->GetFloatArrayElements(jdata, 0);
bst_ulong nindptr = (bst_ulong)jenv->GetArrayLength(jindptr); bst_ulong nindptr = (bst_ulong)jenv->GetArrayLength(jindptr);
bst_ulong nelem = (bst_ulong)jenv->GetArrayLength(jdata); bst_ulong nelem = (bst_ulong)jenv->GetArrayLength(jdata);
int ret = (jint) XGDMatrixCreateFromCSREx((unsigned long const *)indptr, (unsigned int const *)indices, (float const *)data, nindptr, nelem, jcol, &result); jint ret = (jint) XGDMatrixCreateFromCSREx((size_t const *)indptr, (unsigned int const *)indices, (float const *)data, nindptr, nelem, jcol, &result);
setHandle(jenv, jout, result); setHandle(jenv, jout, result);
//Release //Release
jenv->ReleaseLongArrayElements(jindptr, indptr, 0); jenv->ReleaseLongArrayElements(jindptr, indptr, 0);
@ -222,7 +227,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFro
bst_ulong nindptr = (bst_ulong)jenv->GetArrayLength(jindptr); bst_ulong nindptr = (bst_ulong)jenv->GetArrayLength(jindptr);
bst_ulong nelem = (bst_ulong)jenv->GetArrayLength(jdata); bst_ulong nelem = (bst_ulong)jenv->GetArrayLength(jdata);
int ret = (jint) XGDMatrixCreateFromCSCEx((unsigned long const *)indptr, (unsigned int const *)indices, (float const *)data, nindptr, nelem, jrow, &result); jint ret = (jint) XGDMatrixCreateFromCSCEx((size_t const *)indptr, (unsigned int const *)indices, (float const *)data, nindptr, nelem, jrow, &result);
setHandle(jenv, jout, result); setHandle(jenv, jout, result);
//release //release
jenv->ReleaseLongArrayElements(jindptr, indptr, 0); jenv->ReleaseLongArrayElements(jindptr, indptr, 0);
@ -244,7 +249,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFro
jfloat* data = jenv->GetFloatArrayElements(jdata, 0); jfloat* data = jenv->GetFloatArrayElements(jdata, 0);
bst_ulong nrow = (bst_ulong)jnrow; bst_ulong nrow = (bst_ulong)jnrow;
bst_ulong ncol = (bst_ulong)jncol; bst_ulong ncol = (bst_ulong)jncol;
int ret = (jint) XGDMatrixCreateFromMat((float const *)data, nrow, ncol, jmiss, &result); jint ret = (jint) XGDMatrixCreateFromMat((float const *)data, nrow, ncol, jmiss, &result);
setHandle(jenv, jout, result); setHandle(jenv, jout, result);
//release //release
jenv->ReleaseFloatArrayElements(jdata, data, 0); jenv->ReleaseFloatArrayElements(jdata, data, 0);
@ -264,7 +269,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSliceDMat
jint* indexset = jenv->GetIntArrayElements(jindexset, 0); jint* indexset = jenv->GetIntArrayElements(jindexset, 0);
bst_ulong len = (bst_ulong)jenv->GetArrayLength(jindexset); bst_ulong len = (bst_ulong)jenv->GetArrayLength(jindexset);
int ret = XGDMatrixSliceDMatrix(handle, (int const *)indexset, len, &result); jint ret = (jint) XGDMatrixSliceDMatrix(handle, (int const *)indexset, len, &result);
setHandle(jenv, jout, result); setHandle(jenv, jout, result);
//release //release
jenv->ReleaseIntArrayElements(jindexset, indexset, 0); jenv->ReleaseIntArrayElements(jindexset, indexset, 0);
@ -650,7 +655,8 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterLoadRabit
BoosterHandle handle = (BoosterHandle) jhandle; BoosterHandle handle = (BoosterHandle) jhandle;
int version; int version;
int ret = XGBoosterLoadRabitCheckpoint(handle, &version); int ret = XGBoosterLoadRabitCheckpoint(handle, &version);
jenv->SetIntArrayRegion(jout, 0, 1, &version); jint jversion = version;
jenv->SetIntArrayRegion(jout, 0, 1, &jversion);
return ret; return ret;
} }
@ -722,7 +728,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitTrackerPrint
*/ */
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitGetRank JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitGetRank
(JNIEnv *jenv, jclass jcls, jintArray jout) { (JNIEnv *jenv, jclass jcls, jintArray jout) {
int rank = RabitGetRank(); jint rank = RabitGetRank();
jenv->SetIntArrayRegion(jout, 0, 1, &rank); jenv->SetIntArrayRegion(jout, 0, 1, &rank);
return 0; return 0;
} }
@ -734,7 +740,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitGetRank
*/ */
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitGetWorldSize JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitGetWorldSize
(JNIEnv *jenv, jclass jcls, jintArray jout) { (JNIEnv *jenv, jclass jcls, jintArray jout) {
int out = RabitGetWorldSize(); jint out = RabitGetWorldSize();
jenv->SetIntArrayRegion(jout, 0, 1, &out); jenv->SetIntArrayRegion(jout, 0, 1, &out);
return 0; return 0;
} }
@ -746,7 +752,7 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitGetWorldSize
*/ */
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitVersionNumber JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitVersionNumber
(JNIEnv *jenv, jclass jcls, jintArray jout) { (JNIEnv *jenv, jclass jcls, jintArray jout) {
int out = RabitVersionNumber(); jint out = RabitVersionNumber();
jenv->SetIntArrayRegion(jout, 0, 1, &out); jenv->SetIntArrayRegion(jout, 0, 1, &out);
return 0; return 0;
} }