[jvm-packages][xgboost4j-gpu] Support GPU dataframe and DeviceQuantileDMatrix (#7195)
The following classes are added to support dataframes in the Java binding: - `Column` is an abstract type for a single column in tabular data. - `ColumnBatch` is an abstract type for a dataframe. - `CuDFColumn` is an implementation of `Column` that consumes a cuDF column. - `CudfColumnBatch` is an implementation of `ColumnBatch` that consumes a cuDF dataframe. - `DeviceQuantileDMatrix` is the interface for quantized data. The Java implementation mimics the Python interface and uses the `__cuda_array_interface__` protocol for memory indexing. One difference is that in the JVM package the data batch is staged on the host, because Java iterators cannot be reset. Co-authored-by: jiamingy <jm.yuan@outlook.com>
This commit is contained in:
@@ -19,6 +19,7 @@
|
||||
#include <xgboost/c_api.h>
|
||||
#include <xgboost/base.h>
|
||||
#include <xgboost/logging.h>
|
||||
#include <xgboost/json.h>
|
||||
#include "./xgboost4j.h"
|
||||
#include <cstring>
|
||||
#include <vector>
|
||||
@@ -43,12 +44,14 @@ void setHandle(JNIEnv *jenv, jlongArray jhandle, void* handle) {
|
||||
jenv->SetLongArrayRegion(jhandle, 0, 1, &out);
|
||||
}
|
||||
|
||||
// global JVM
|
||||
static JavaVM* global_jvm = nullptr;
|
||||
// Accessor for the process-wide JavaVM pointer.  Returns a reference to a
// function-local static so callers can both read and assign the stored pointer.
JavaVM*& GlobalJvm() {
  static JavaVM* jvm_instance;  // zero-initialized until something assigns it
  return jvm_instance;
}
|
||||
|
||||
// overrides JNI on load
|
||||
jint JNI_OnLoad(JavaVM *vm, void *reserved) {
|
||||
global_jvm = vm;
|
||||
GlobalJvm() = vm;
|
||||
return JNI_VERSION_1_6;
|
||||
}
|
||||
|
||||
@@ -58,9 +61,9 @@ XGB_EXTERN_C int XGBoost4jCallbackDataIterNext(
|
||||
DataHolderHandle set_function_handle) {
|
||||
jobject jiter = static_cast<jobject>(data_handle);
|
||||
JNIEnv* jenv;
|
||||
int jni_status = global_jvm->GetEnv((void **)&jenv, JNI_VERSION_1_6);
|
||||
int jni_status = GlobalJvm()->GetEnv((void **)&jenv, JNI_VERSION_1_6);
|
||||
if (jni_status == JNI_EDETACHED) {
|
||||
global_jvm->AttachCurrentThread(reinterpret_cast<void **>(&jenv), nullptr);
|
||||
GlobalJvm()->AttachCurrentThread(reinterpret_cast<void **>(&jenv), nullptr);
|
||||
} else {
|
||||
CHECK(jni_status == JNI_OK);
|
||||
}
|
||||
@@ -148,13 +151,13 @@ XGB_EXTERN_C int XGBoost4jCallbackDataIterNext(
|
||||
jenv->DeleteLocalRef(iterClass);
|
||||
// only detach if it is a async call.
|
||||
if (jni_status == JNI_EDETACHED) {
|
||||
global_jvm->DetachCurrentThread();
|
||||
GlobalJvm()->DetachCurrentThread();
|
||||
}
|
||||
return ret_value;
|
||||
} catch(dmlc::Error const& e) {
|
||||
// only detach if it is a async call.
|
||||
if (jni_status == JNI_EDETACHED) {
|
||||
global_jvm->DetachCurrentThread();
|
||||
GlobalJvm()->DetachCurrentThread();
|
||||
}
|
||||
LOG(FATAL) << e.what();
|
||||
return -1;
|
||||
@@ -968,3 +971,71 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitAllreduce
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
namespace xgboost {
|
||||
namespace jni {
|
||||
XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass jcls,
|
||||
jobject jiter,
|
||||
jfloat jmissing,
|
||||
jint jmax_bin, jint jnthread,
|
||||
jlongArray jout);
|
||||
} // namespace jni
|
||||
} // namespace xgboost
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGDeviceQuantileDMatrixCreateFromCallback
|
||||
* Signature: (Ljava/util/Iterator;FII[J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDeviceQuantileDMatrixCreateFromCallback
  (JNIEnv *jenv, jclass jcls, jobject jiter, jfloat jmissing, jint jmax_bin,
   jint jnthread, jlongArray jout) {
  // Thin JNI shim: hand every argument, unchanged, to the implementation
  // declared in xgboost::jni and report its status code back to the caller.
  using xgboost::jni::XGDeviceQuantileDMatrixCreateFromCallbackImpl;
  int const status = XGDeviceQuantileDMatrixCreateFromCallbackImpl(
      jenv, jcls, jiter, jmissing, jmax_bin, jnthread, jout);
  return status;
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGDMatrixSetInfoFromInterface
|
||||
* Signature: (JLjava/lang/String;Ljava/lang/String;)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetInfoFromInterface
  (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jstring jjson_columns) {
  // The native DMatrix handle arrives from Java packed into a jlong.
  DMatrixHandle handle = reinterpret_cast<DMatrixHandle>(jhandle);
  const char* field = jenv->GetStringUTFChars(jfield, nullptr);
  const char* cjson_columns = jenv->GetStringUTFChars(jjson_columns, nullptr);

  int ret = XGDMatrixSetInfoFromInterface(handle, field, cjson_columns);

  // Release the UTF buffers before the status check.  JVM_CHECK_CALL is defined
  // outside this section; if it returns early on a non-zero status, releasing
  // after it would leak both buffers on every failed call.
  if (field) jenv->ReleaseStringUTFChars(jfield, field);
  if (cjson_columns) jenv->ReleaseStringUTFChars(jjson_columns, cjson_columns);

  JVM_CHECK_CALL(ret);
  return ret;
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGDMatrixCreateFromArrayInterfaceColumns
|
||||
* Signature: (Ljava/lang/String;FI[J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromArrayInterfaceColumns
  (JNIEnv *jenv, jclass jcls, jstring jjson_columns, jfloat jmissing, jint jnthread, jlongArray jout) {
  // Start from a null handle so a failed create never leaves an indeterminate
  // pointer that could be written into jout.
  DMatrixHandle result = nullptr;
  const char* cjson_columns = jenv->GetStringUTFChars(jjson_columns, nullptr);

  // Pack the scalar arguments into the JSON config string taken by the C API.
  xgboost::Json config{xgboost::Object{}};
  auto missing = static_cast<float>(jmissing);
  auto n_threads = static_cast<int32_t>(jnthread);
  config["missing"] = xgboost::Number(missing);
  config["nthread"] = xgboost::Integer(n_threads);
  std::string config_str;
  xgboost::Json::Dump(config, &config_str);

  int ret = XGDMatrixCreateFromCudaColumnar(cjson_columns, config_str.c_str(),
                                            &result);

  // Release the UTF buffer before the status check.  JVM_CHECK_CALL is defined
  // outside this section; if it returns early on a non-zero status, releasing
  // after it would leak the buffer on every failed call.
  if (cjson_columns) {
    jenv->ReleaseStringUTFChars(jjson_columns, cjson_columns);
  }
  JVM_CHECK_CALL(ret);

  // Publish the new handle to the Java caller through the output array.
  setHandle(jenv, jout, result);
  return ret;
}
|
||||
|
||||
Reference in New Issue
Block a user