[jvm-packages][xgboost4j-gpu] Support GPU dataframe and DeviceQuantileDMatrix (#7195)

The following classes are added to support dataframes in the Java binding:

- `Column` is an abstract type for a single column in tabular data.
- `ColumnBatch` is an abstract type for dataframe.

- `CuDFColumn` is an implementation of `Column` that consumes a cuDF column.
- `CudfColumnBatch` is an implementation of `ColumnBatch` that consumes a cuDF dataframe.

- `DeviceQuantileDMatrix` is the interface for quantized data.

The Java implementation mimics the Python interface and uses the `__cuda_array_interface__` protocol for memory indexing. One difference is that in the JVM package the data batches are staged on the host, because Java iterators cannot be reset.

Co-authored-by: jiamingy <jm.yuan@outlook.com>
This commit is contained in:
Bobby Wang
2021-09-24 14:25:00 +08:00
committed by GitHub
parent d27a427dc5
commit 0ee11dac77
23 changed files with 1388 additions and 18 deletions

View File

@@ -19,6 +19,7 @@
#include <xgboost/c_api.h>
#include <xgboost/base.h>
#include <xgboost/logging.h>
#include <xgboost/json.h>
#include "./xgboost4j.h"
#include <cstring>
#include <vector>
@@ -43,12 +44,14 @@ void setHandle(JNIEnv *jenv, jlongArray jhandle, void* handle) {
jenv->SetLongArrayRegion(jhandle, 0, 1, &out);
}
// global JVM
static JavaVM* global_jvm = nullptr;
// Accessor for the shared JavaVM pointer.
// Returns a mutable reference to a function-local static, so one call site can
// store the pointer (`GlobalJvm() = vm`) and another can read it back.
JavaVM*& GlobalJvm() {
  static JavaVM* instance = nullptr;
  return instance;
}
// Overrides JNI_OnLoad: records the JavaVM pointer received as `vm` so that
// native callbacks can later obtain a JNIEnv from it, then reports
// JNI_VERSION_1_6 as the return value. `reserved` is not used.
// NOTE(review): the pointer is stored both in `global_jvm` and via
// `GlobalJvm()`. This text looks like a rendered diff, so presumably the
// `global_jvm` assignment is the removed line and only the `GlobalJvm()` one
// remains in the real file -- confirm against the actual source.
jint JNI_OnLoad(JavaVM *vm, void *reserved) {
global_jvm = vm;
GlobalJvm() = vm;
return JNI_VERSION_1_6;
}
@@ -58,9 +61,9 @@ XGB_EXTERN_C int XGBoost4jCallbackDataIterNext(
DataHolderHandle set_function_handle) {
jobject jiter = static_cast<jobject>(data_handle);
JNIEnv* jenv;
int jni_status = global_jvm->GetEnv((void **)&jenv, JNI_VERSION_1_6);
int jni_status = GlobalJvm()->GetEnv((void **)&jenv, JNI_VERSION_1_6);
if (jni_status == JNI_EDETACHED) {
global_jvm->AttachCurrentThread(reinterpret_cast<void **>(&jenv), nullptr);
GlobalJvm()->AttachCurrentThread(reinterpret_cast<void **>(&jenv), nullptr);
} else {
CHECK(jni_status == JNI_OK);
}
@@ -148,13 +151,13 @@ XGB_EXTERN_C int XGBoost4jCallbackDataIterNext(
jenv->DeleteLocalRef(iterClass);
// only detach if it is a async call.
if (jni_status == JNI_EDETACHED) {
global_jvm->DetachCurrentThread();
GlobalJvm()->DetachCurrentThread();
}
return ret_value;
} catch(dmlc::Error const& e) {
// only detach if it is a async call.
if (jni_status == JNI_EDETACHED) {
global_jvm->DetachCurrentThread();
GlobalJvm()->DetachCurrentThread();
}
LOG(FATAL) << e.what();
return -1;
@@ -968,3 +971,71 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitAllreduce
return 0;
}
namespace xgboost {
namespace jni {
// Forward declaration; the definition lives outside this view.
// The parameter list mirrors the JNI entry point
// Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDeviceQuantileDMatrixCreateFromCallback,
// which forwards every argument to this function unchanged.
XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallbackImpl(
    JNIEnv *jenv, jclass jcls, jobject jiter, jfloat jmissing, jint jmax_bin,
    jint jnthread, jlongArray jout);
}  // namespace jni
}  // namespace xgboost
/*
 * Class:     ml_dmlc_xgboost4j_java_XGBoostJNI
 * Method:    XGDeviceQuantileDMatrixCreateFromCallback
 * Signature: (Ljava/util/Iterator;FII[J)I
 */
JNIEXPORT jint JNICALL
Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDeviceQuantileDMatrixCreateFromCallback(
    JNIEnv *jenv, jclass jcls, jobject jiter, jfloat jmissing, jint jmax_bin,
    jint jnthread, jlongArray jout) {
  // Thin JNI shim: every argument is handed through untouched to the
  // implementation declared in xgboost::jni, and its status code is returned
  // to the Java caller as-is.
  namespace impl = xgboost::jni;
  int const status = impl::XGDeviceQuantileDMatrixCreateFromCallbackImpl(
      jenv, jcls, jiter, jmissing, jmax_bin, jnthread, jout);
  return status;
}
/*
 * Class:     ml_dmlc_xgboost4j_java_XGBoostJNI
 * Method:    XGDMatrixSetInfoFromInterface
 * Signature: (JLjava/lang/String;Ljava/lang/String;)I
 */
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetInfoFromInterface
  (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jstring jjson_columns) {
  // Converts the two Java strings to UTF-8 C strings and forwards them, with
  // the DMatrix handle carried in jhandle, to XGDMatrixSetInfoFromInterface.
  // Returns that call's status, or -1 if a string could not be converted.
  DMatrixHandle handle = reinterpret_cast<DMatrixHandle>(jhandle);

  // GetStringUTFChars returns NULL on failure (with a Java exception pending),
  // so the second conversion is attempted only if the first one succeeded.
  const char* field = jenv->GetStringUTFChars(jfield, nullptr);
  const char* cjson_columns =
      field != nullptr ? jenv->GetStringUTFChars(jjson_columns, nullptr) : nullptr;
  if (field == nullptr || cjson_columns == nullptr) {
    // Never hand a null C string to the C API; give back what was acquired.
    if (field != nullptr) jenv->ReleaseStringUTFChars(jfield, field);
    return -1;
  }

  int ret = XGDMatrixSetInfoFromInterface(handle, field, cjson_columns);

  // Release before checking the status: JVM_CHECK_CALL is defined outside this
  // view and presumably returns early on a non-zero code, which would
  // otherwise skip the releases and leak both buffers.
  jenv->ReleaseStringUTFChars(jfield, field);
  jenv->ReleaseStringUTFChars(jjson_columns, cjson_columns);

  JVM_CHECK_CALL(ret);
  return ret;
}
/*
 * Class:     ml_dmlc_xgboost4j_java_XGBoostJNI
 * Method:    XGDMatrixCreateFromArrayInterfaceColumns
 * Signature: (Ljava/lang/String;FI[J)I
 */
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromArrayInterfaceColumns
  (JNIEnv *jenv, jclass jcls, jstring jjson_columns, jfloat jmissing, jint jnthread, jlongArray jout) {
  // Builds a JSON config holding "missing" and "nthread", passes it together
  // with the columns JSON string to XGDMatrixCreateFromCudaColumnar, and writes
  // the resulting handle into jout through setHandle.
  // Returns the C API status, or -1 if the Java string could not be converted.

  // Initialised so that a failed create can never expose an indeterminate
  // pointer to setHandle.
  DMatrixHandle result = nullptr;

  // GetStringUTFChars returns NULL on failure (with a Java exception pending);
  // do not pass a null C string on to the C API.
  const char* cjson_columns = jenv->GetStringUTFChars(jjson_columns, nullptr);
  if (cjson_columns == nullptr) {
    return -1;
  }

  xgboost::Json config{xgboost::Object{}};
  auto missing = static_cast<float>(jmissing);
  auto n_threads = static_cast<int32_t>(jnthread);
  config["missing"] = xgboost::Number(missing);
  config["nthread"] = xgboost::Integer(n_threads);
  std::string config_str;
  xgboost::Json::Dump(config, &config_str);

  int ret = XGDMatrixCreateFromCudaColumnar(cjson_columns, config_str.c_str(),
                                            &result);

  // Release before checking the status: JVM_CHECK_CALL is defined outside this
  // view and presumably returns early on a non-zero code, which would
  // otherwise skip the release and leak the buffer.
  jenv->ReleaseStringUTFChars(jjson_columns, cjson_columns);

  JVM_CHECK_CALL(ret);
  setHandle(jenv, jout, result);
  return ret;
}