[jvm-packages][xgboost4j-gpu] Support GPU dataframe and DeviceQuantileDMatrix (#7195)

Following classes are added to support dataframe in java binding:

- `Column` is an abstract type for a single column in tabular data.
- `ColumnBatch` is an abstract type for dataframe.

- `CuDFColumn` is an implementaiton of `Column` that consume cuDF column
- `CudfColumnBatch` is an implementation of `ColumnBatch` that consumes cuDF dataframe.

- `DeviceQuantileDMatrix` is the interface for quantized data.

The Java implementation mimics the Python interface and uses `__cuda_array_interface__` protocol for memory indexing.  One difference is on JVM package, the data batch is staged on the host as java iterators cannot be reset.

Co-authored-by: jiamingy <jm.yuan@outlook.com>
This commit is contained in:
Bobby Wang
2021-09-24 14:25:00 +08:00
committed by GitHub
parent d27a427dc5
commit 0ee11dac77
23 changed files with 1388 additions and 18 deletions

View File

@@ -335,6 +335,30 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitVersionNumber
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitAllreduce
(JNIEnv *, jclass, jobject, jint, jint, jint);
/*
* Class: ml_dmlc_xgboost4j_java_GpuXGBoostJNI
* Method: XGDMatrixSetInfoFromInterface
* Signature: (JLjava/lang/String;Ljava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetInfoFromInterface
(JNIEnv *, jclass, jlong, jstring, jstring);
/*
* Class: ml_dmlc_xgboost4j_java_GpuXGBoostJNI
* Method: XGDeviceQuantileDMatrixCreateFromCallback
* Signature: (Ljava/util/Iterator;FII[J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDeviceQuantileDMatrixCreateFromCallback
(JNIEnv *, jclass, jobject, jfloat, jint, jint, jlongArray);
/*
* Class: ml_dmlc_xgboost4j_java_GpuXGBoostJNI
* Method: XGDMatrixCreateFromArrayInterfaceColumns
* Signature: (Ljava/lang/String;FI[J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromArrayInterfaceColumns
(JNIEnv *, jclass, jstring, jfloat, jint, jlongArray);
#ifdef __cplusplus
}
#endif