[jvm-packages][xgboost4j-gpu] Support GPU dataframe and DeviceQuantileDMatrix (#7195)

The following classes are added to support dataframes in the Java binding:

- `Column` is an abstract type for a single column in tabular data.
- `ColumnBatch` is an abstract type for dataframe.

- `CuDFColumn` is an implementation of `Column` that consumes a cuDF column.
- `CudfColumnBatch` is an implementation of `ColumnBatch` that consumes cuDF dataframe.

- `DeviceQuantileDMatrix` is the interface for quantized data.

The Java implementation mimics the Python interface and uses the `__cuda_array_interface__` protocol for memory indexing. One difference is that in the JVM package the data batches are staged on the host, because Java iterators cannot be reset.

Co-authored-by: jiamingy <jm.yuan@outlook.com>
This commit is contained in:
Bobby Wang
2021-09-24 14:25:00 +08:00
committed by GitHub
parent d27a427dc5
commit 0ee11dac77
23 changed files with 1388 additions and 18 deletions

View File

@@ -0,0 +1,40 @@
/*
Copyright (c) 2021 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.java;
/**
 * Abstraction of a single XGBoost column that can describe itself through the
 * cuda array interface, which is used to set information on a DMatrix.
 */
public abstract class Column implements AutoCloseable {

  /**
   * Produce the cuda array interface of this column as a json string. The column
   * may represent a weight, label or base margin column.
   *
   * This method is consumed by
   * {@link DMatrix#setLabel(Column)},
   * {@link DMatrix#setWeight(Column)} and
   * {@link DMatrix#setBaseMargin(Column)}.
   *
   * @return the cuda array interface json string of this column
   */
  public abstract String getArrayInterfaceJson();

  /**
   * Does nothing by default.
   *
   * @throws Exception declared so that overriding implementations may throw
   */
  @Override
  public void close() throws Exception {
  }
}

View File

@@ -0,0 +1,93 @@
/*
Copyright (c) 2021 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.java;
import java.util.Iterator;
/**
 * Abstraction of an XGBoost batch of columnar data that exposes its content
 * through array interfaces. One example is a cuDF dataframe, which employs the
 * apache arrow specification.
 */
public abstract class ColumnBatch implements AutoCloseable {

  /**
   * Assemble one json object holding the cuda array interfaces of the whole batch:
   * the mandatory feature and label columns followed by the optional weight and
   * base margin columns (skipped when null or empty).
   *
   * This method is called by native code during iteration and could be private;
   * it is kept public simply to silence the linter.
   *
   * @return the combined json string
   * @throws RuntimeException if the feature or the label array interface is null or empty
   */
  public final String getArrayInterfaceJson() {
    // The getters are queried in a fixed order: features, labels, weights, base margins.
    final String features = this.getFeatureArrayInterface();
    if (!hasContent(features)) {
      throw new RuntimeException("Feature array interface must not be empty");
    }

    final String labels = this.getLabelsArrayInterface();
    if (!hasContent(labels)) {
      throw new RuntimeException("Label array interface must not be empty");
    }

    final StringBuilder json = new StringBuilder("{");
    json.append("\"features_str\":").append(features);
    json.append(",\"label_str\":").append(labels);

    final String weights = getWeightsArrayInterface();
    if (hasContent(weights)) {
      json.append(",\"weight_str\":").append(weights);
    }

    final String baseMargins = getBaseMarginsArrayInterface();
    if (hasContent(baseMargins)) {
      json.append(",\"basemargin_str\":").append(baseMargins);
    }

    return json.append("}").toString();
  }

  /** Tell whether the given array interface string is neither null nor empty. */
  private static boolean hasContent(String arrayInterface) {
    return arrayInterface != null && !arrayInterface.isEmpty();
  }

  /**
   * Get the cuda array interface of the feature columns.
   * The returned value must not be null or empty.
   */
  public abstract String getFeatureArrayInterface();

  /**
   * Get the cuda array interface of the label columns.
   * The returned value must not be null or empty if we're creating
   * {@link DeviceQuantileDMatrix#DeviceQuantileDMatrix(Iterator, float, int, int)}
   */
  public abstract String getLabelsArrayInterface();

  /**
   * Get the cuda array interface of the weight columns.
   * The returned value can be null or empty.
   */
  public abstract String getWeightsArrayInterface();

  /**
   * Get the cuda array interface of the base margin columns.
   * The returned value can be null or empty.
   */
  public abstract String getBaseMarginsArrayInterface();

  /**
   * Does nothing by default.
   *
   * @throws Exception declared so that overriding implementations may throw
   */
  @Override
  public void close() throws Exception {
  }
}

View File

@@ -177,6 +177,64 @@ public class DMatrix {
this.handle = handle;
}
/**
 * Build a normal DMatrix out of the feature columns of a column batch.
 *
 * @param columnBatch the XGBoost ColumnBatch providing the cuda array interface
 *                    of the feature columns
 * @param missing     missing value
 * @param nthread     threads number
 * @throws XGBoostError if the feature array interface is null or empty, or if the
 *                      native call reports an error
 */
public DMatrix(ColumnBatch columnBatch, float missing, int nthread) throws XGBoostError {
  final String featureJson = columnBatch.getFeatureArrayInterface();
  final boolean hasFeatures = featureJson != null && !featureJson.isEmpty();
  if (!hasFeatures) {
    throw new XGBoostError("Expecting non-empty feature columns' array interface");
  }

  // The native side hands the created handle back through a one-element array.
  final long[] handleOut = new long[1];
  final int status = XGBoostJNI.XGDMatrixCreateFromArrayInterfaceColumns(
      featureJson, missing, nthread, handleOut);
  XGBoostJNI.checkCall(status);
  handle = handleOut[0];
}
/**
 * Attach labels to this DMatrix, reading them through the cuda array interface.
 *
 * @param column the XGBoost Column providing the cuda array interface
 *               of the label column
 * @throws XGBoostError if the array interface is null or empty, or on native error
 */
public void setLabel(Column column) throws XGBoostError {
  final String labelJson = column.getArrayInterfaceJson();
  setXGBDMatrixInfo("label", labelJson);
}
/**
 * Attach weights to this DMatrix, reading them through the cuda array interface.
 *
 * @param column the XGBoost Column providing the cuda array interface
 *               of the weight column
 * @throws XGBoostError if the array interface is null or empty, or on native error
 */
public void setWeight(Column column) throws XGBoostError {
  final String weightJson = column.getArrayInterfaceJson();
  setXGBDMatrixInfo("weight", weightJson);
}
/**
 * Attach base margins to this DMatrix, reading them through the cuda array interface.
 *
 * @param column the XGBoost Column providing the cuda array interface
 *               of the base margin column
 * @throws XGBoostError if the array interface is null or empty, or on native error
 */
public void setBaseMargin(Column column) throws XGBoostError {
  final String baseMarginJson = column.getArrayInterfaceJson();
  setXGBDMatrixInfo("base_margin", baseMarginJson);
}
/**
 * Forward an array interface json string to the native library for the given info field.
 *
 * @param type name of the info field, also used in the error message
 * @param json array interface json string; must be neither null nor empty
 * @throws XGBoostError if {@code json} is null or empty, or if the native call fails
 */
private void setXGBDMatrixInfo(String type, String json) throws XGBoostError {
  final boolean hasContent = json != null && !json.isEmpty();
  if (!hasContent) {
    throw new XGBoostError("Empty " + type + " columns' array interface");
  }
  final int status = XGBoostJNI.XGDMatrixSetInfoFromInterface(handle, type, json);
  XGBoostJNI.checkCall(status);
}
/**
* set label of dmatrix

View File

@@ -0,0 +1,68 @@
package ml.dmlc.xgboost4j.java;
import java.util.Iterator;
/**
 * DeviceQuantileDMatrix will only be used to train.
 *
 * Its content is supplied at construction time; every setter overridden below
 * rejects the call with an {@link XGBoostError}.
 */
public class DeviceQuantileDMatrix extends DMatrix {

  /**
   * Create DeviceQuantileDMatrix from an iterator based on the cuda array interface.
   *
   * @param iter    iterator over the XGBoost ColumnBatch batches providing the
   *                corresponding cuda array interface
   * @param missing the missing value
   * @param maxBin  the max bin
   * @param nthread the parallelism
   * @throws XGBoostError if the native call reports an error
   */
  public DeviceQuantileDMatrix(
      Iterator<ColumnBatch> iter,
      float missing,
      int maxBin,
      int nthread) throws XGBoostError {
    // Start with a zero handle; the real one is filled in once the native call succeeded.
    super(0);
    final long[] handleOut = new long[1];
    final int status = XGBoostJNI.XGDeviceQuantileDMatrixCreateFromCallback(
        iter, missing, maxBin, nthread, handleOut);
    XGBoostJNI.checkCall(status);
    handle = handleOut[0];
  }

  /** Not supported: always throws. */
  @Override
  public void setLabel(Column column) throws XGBoostError {
    throw new XGBoostError("DeviceQuantileDMatrix does not support setLabel.");
  }

  /** Not supported: always throws. */
  @Override
  public void setWeight(Column column) throws XGBoostError {
    throw new XGBoostError("DeviceQuantileDMatrix does not support setWeight.");
  }

  /** Not supported: always throws. */
  @Override
  public void setBaseMargin(Column column) throws XGBoostError {
    throw new XGBoostError("DeviceQuantileDMatrix does not support setBaseMargin.");
  }

  /** Not supported: always throws. */
  @Override
  public void setLabel(float[] labels) throws XGBoostError {
    throw new XGBoostError("DeviceQuantileDMatrix does not support setLabel.");
  }

  /** Not supported: always throws. */
  @Override
  public void setWeight(float[] weights) throws XGBoostError {
    throw new XGBoostError("DeviceQuantileDMatrix does not support setWeight.");
  }

  /** Not supported: always throws. */
  @Override
  public void setBaseMargin(float[] baseMargin) throws XGBoostError {
    throw new XGBoostError("DeviceQuantileDMatrix does not support setBaseMargin.");
  }

  /** Not supported: always throws. */
  @Override
  public void setBaseMargin(float[][] baseMargin) throws XGBoostError {
    throw new XGBoostError("DeviceQuantileDMatrix does not support setBaseMargin.");
  }

  /** Not supported: always throws. */
  @Override
  public void setGroup(int[] group) throws XGBoostError {
    throw new XGBoostError("DeviceQuantileDMatrix does not support setGroup.");
  }
}

View File

@@ -134,4 +134,13 @@ class XGBoostJNI {
// This JNI function does not support the callback function for data preparation yet.
final static native int RabitAllreduce(ByteBuffer sendrecvbuf, int count,
int enum_dtype, int enum_op);
/**
 * Native binding that sets an info field of a DMatrix from an array interface json string.
 * DMatrix invokes it with the fields "label", "weight" and "base_margin".
 *
 * @param handle native handle of the DMatrix
 * @param field  name of the info field to set
 * @param json   array interface json string describing the data
 * @return status code of the native call; callers hand it to checkCall
 */
public final static native int XGDMatrixSetInfoFromInterface(
long handle, String field, String json);
/**
 * Native binding that creates a DeviceQuantileDMatrix from an iterator of column batches.
 *
 * The parameter order matches the native implementation and the caller in
 * DeviceQuantileDMatrix: {@code maxBin} comes before {@code nthread}.
 *
 * @param iter    iterator over the column batches
 * @param missing the missing value
 * @param maxBin  the max bin
 * @param nthread the parallelism
 * @param out     one-element array; the caller reads the created native handle from index 0
 * @return status code of the native call; callers hand it to checkCall
 */
public final static native int XGDeviceQuantileDMatrixCreateFromCallback(
    java.util.Iterator<ColumnBatch> iter, float missing, int maxBin, int nthread, long[] out);
/**
 * Native binding that creates a DMatrix from the array interface json of feature columns.
 *
 * @param featureJson array interface json string of the feature columns
 * @param missing     the missing value
 * @param nthread     number of threads
 * @param out         one-element array; the created native handle is written to index 0
 * @return status code of the native call; callers hand it to checkCall
 */
public final static native int XGDMatrixCreateFromArrayInterfaceColumns(
String featureJson, float missing, int nthread, long[] out);
}

View File

@@ -19,6 +19,7 @@
#include <xgboost/c_api.h>
#include <xgboost/base.h>
#include <xgboost/logging.h>
#include <xgboost/json.h>
#include "./xgboost4j.h"
#include <cstring>
#include <vector>
@@ -43,12 +44,14 @@ void setHandle(JNIEnv *jenv, jlongArray jhandle, void* handle) {
jenv->SetLongArrayRegion(jhandle, 0, 1, &out);
}
// global JVM
// File-local JavaVM pointer; JNI_OnLoad assigns it when the library is loaded.
static JavaVM* global_jvm = nullptr;
// Accessor for the JavaVM pointer shared by the JNI glue code. It returns a
// reference, so the same call is used both to read the pointer and to assign it.
JavaVM*& GlobalJvm() {
  // Function-local static: null until a caller assigns through the reference.
  static JavaVM* instance = nullptr;
  return instance;
}
// overrides JNI on load
// Records the JavaVM pointer passed in as `vm` so that later callbacks can obtain a
// JNIEnv from it, and returns JNI_VERSION_1_6. `reserved` is not used.
jint JNI_OnLoad(JavaVM *vm, void *reserved) {
global_jvm = vm;
GlobalJvm() = vm;
return JNI_VERSION_1_6;
}
@@ -58,9 +61,9 @@ XGB_EXTERN_C int XGBoost4jCallbackDataIterNext(
DataHolderHandle set_function_handle) {
jobject jiter = static_cast<jobject>(data_handle);
JNIEnv* jenv;
int jni_status = global_jvm->GetEnv((void **)&jenv, JNI_VERSION_1_6);
int jni_status = GlobalJvm()->GetEnv((void **)&jenv, JNI_VERSION_1_6);
if (jni_status == JNI_EDETACHED) {
global_jvm->AttachCurrentThread(reinterpret_cast<void **>(&jenv), nullptr);
GlobalJvm()->AttachCurrentThread(reinterpret_cast<void **>(&jenv), nullptr);
} else {
CHECK(jni_status == JNI_OK);
}
@@ -148,13 +151,13 @@ XGB_EXTERN_C int XGBoost4jCallbackDataIterNext(
jenv->DeleteLocalRef(iterClass);
// only detach if it is a async call.
if (jni_status == JNI_EDETACHED) {
global_jvm->DetachCurrentThread();
GlobalJvm()->DetachCurrentThread();
}
return ret_value;
} catch(dmlc::Error const& e) {
// only detach if it is a async call.
if (jni_status == JNI_EDETACHED) {
global_jvm->DetachCurrentThread();
GlobalJvm()->DetachCurrentThread();
}
LOG(FATAL) << e.what();
return -1;
@@ -968,3 +971,71 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitAllreduce
return 0;
}
namespace xgboost {
namespace jni {
// Declaration only; the definition is not part of this section of the file.
// The JNI entry point Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDeviceQuantileDMatrixCreateFromCallback
// forwards all of its arguments to this function and returns its result.
XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass jcls,
jobject jiter,
jfloat jmissing,
jint jmax_bin, jint jnthread,
jlongArray jout);
} // namespace jni
} // namespace xgboost
/*
 * Class:     ml_dmlc_xgboost4j_java_XGBoostJNI
 * Method:    XGDeviceQuantileDMatrixCreateFromCallback
 * Signature: (Ljava/util/Iterator;FII[J)I
 *
 * Thin JNI entry point: every argument is passed through unchanged to
 * xgboost::jni::XGDeviceQuantileDMatrixCreateFromCallbackImpl and its result is returned.
 */
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDeviceQuantileDMatrixCreateFromCallback
  (JNIEnv *jenv, jclass jcls, jobject jiter, jfloat jmissing, jint jmax_bin,
   jint jnthread, jlongArray jout) {
  int const status = xgboost::jni::XGDeviceQuantileDMatrixCreateFromCallbackImpl(
      jenv, jcls, jiter, jmissing, jmax_bin, jnthread, jout);
  return status;
}
/*
 * Class:     ml_dmlc_xgboost4j_java_XGBoostJNI
 * Method:    XGDMatrixSetInfoFromInterface
 * Signature: (JLjava/lang/String;Ljava/lang/String;)I
 *
 * Converts the Java strings to UTF-8, forwards them together with the DMatrix handle to
 * XGDMatrixSetInfoFromInterface and returns that call's status code.
 */
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetInfoFromInterface
  (JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jstring jjson_columns) {
  DMatrixHandle handle = (DMatrixHandle) jhandle;
  const char* field = jenv->GetStringUTFChars(jfield, 0);
  const char* cjson_columns = jenv->GetStringUTFChars(jjson_columns, 0);

  int ret = XGDMatrixSetInfoFromInterface(handle, field, cjson_columns);

  // Release the UTF buffers before inspecting the status. JVM_CHECK_CALL is defined
  // outside this section and presumably returns early on a non-zero status; releasing
  // first guarantees the buffers are never leaked on the error path.
  if (field) jenv->ReleaseStringUTFChars(jfield, field);
  if (cjson_columns) jenv->ReleaseStringUTFChars(jjson_columns, cjson_columns);

  JVM_CHECK_CALL(ret);
  return ret;
}
/*
 * Class:     ml_dmlc_xgboost4j_java_XGBoostJNI
 * Method:    XGDMatrixCreateFromArrayInterfaceColumns
 * Signature: (Ljava/lang/String;FI[J)I
 *
 * Builds a json config holding "missing" and "nthread", creates a DMatrix from the
 * columnar array interface via XGDMatrixCreateFromCudaColumnar and stores the new
 * handle in the first element of jout. Returns the status code of the native call.
 */
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromArrayInterfaceColumns
  (JNIEnv *jenv, jclass jcls, jstring jjson_columns, jfloat jmissing, jint jnthread, jlongArray jout) {
  // Start from a null handle so an indeterminate value can never be handed to setHandle.
  DMatrixHandle result = nullptr;
  const char* cjson_columns = jenv->GetStringUTFChars(jjson_columns, nullptr);

  xgboost::Json config{xgboost::Object{}};
  auto missing = static_cast<float>(jmissing);
  auto n_threads = static_cast<int32_t>(jnthread);
  config["missing"] = xgboost::Number(missing);
  config["nthread"] = xgboost::Integer(n_threads);
  std::string config_str;
  xgboost::Json::Dump(config, &config_str);

  int ret = XGDMatrixCreateFromCudaColumnar(cjson_columns, config_str.c_str(),
                                            &result);

  // Release the UTF buffer before inspecting the status. JVM_CHECK_CALL is defined
  // outside this section and presumably returns early on a non-zero status; releasing
  // first guarantees the buffer is never leaked on the error path.
  if (cjson_columns) {
    jenv->ReleaseStringUTFChars(jjson_columns, cjson_columns);
  }

  JVM_CHECK_CALL(ret);
  setHandle(jenv, jout, result);
  return ret;
}

View File

@@ -335,6 +335,30 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitVersionNumber
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitAllreduce
(JNIEnv *, jclass, jobject, jint, jint, jint);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGDMatrixSetInfoFromInterface
* Signature: (JLjava/lang/String;Ljava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetInfoFromInterface
(JNIEnv *, jclass, jlong, jstring, jstring);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGDeviceQuantileDMatrixCreateFromCallback
* Signature: (Ljava/util/Iterator;FII[J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDeviceQuantileDMatrixCreateFromCallback
(JNIEnv *, jclass, jobject, jfloat, jint, jint, jlongArray);
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGDMatrixCreateFromArrayInterfaceColumns
* Signature: (Ljava/lang/String;FI[J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromArrayInterfaceColumns
(JNIEnv *, jclass, jstring, jfloat, jint, jlongArray);
#ifdef __cplusplus
}
#endif