[jvm-packages][xgboost4j-gpu] Support GPU dataframe and DeviceQuantileDMatrix (#7195)
Following classes are added to support dataframe in java binding: - `Column` is an abstract type for a single column in tabular data. - `ColumnBatch` is an abstract type for dataframe. - `CuDFColumn` is an implementaiton of `Column` that consume cuDF column - `CudfColumnBatch` is an implementation of `ColumnBatch` that consumes cuDF dataframe. - `DeviceQuantileDMatrix` is the interface for quantized data. The Java implementation mimics the Python interface and uses `__cuda_array_interface__` protocol for memory indexing. One difference is on JVM package, the data batch is staged on the host as java iterators cannot be reset. Co-authored-by: jiamingy <jm.yuan@outlook.com>
This commit is contained in:
@@ -0,0 +1,40 @@
|
||||
/*
|
||||
Copyright (c) 2021 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
/**
|
||||
* The abstracted XGBoost Column to get the cuda array interface which is used to
|
||||
* set the information for DMatrix.
|
||||
*
|
||||
*/
|
||||
public abstract class Column implements AutoCloseable {
|
||||
|
||||
/**
|
||||
* Get the cuda array interface json string for the Column which can be representing
|
||||
* weight, label, base margin column.
|
||||
*
|
||||
* This API will be called by
|
||||
* {@link DMatrix#setLabel(Column)}
|
||||
* {@link DMatrix#setWeight(Column)}
|
||||
* {@link DMatrix#setBaseMargin(Column)}
|
||||
*/
|
||||
public abstract String getArrayInterfaceJson();
|
||||
|
||||
@Override
|
||||
public void close() throws Exception {}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,93 @@
|
||||
/*
|
||||
Copyright (c) 2021 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
import java.util.Iterator;
|
||||
|
||||
/**
|
||||
* The abstracted XGBoost ColumnBatch to get array interface from columnar data format.
|
||||
* For example, the cuDF dataframe which employs apache arrow specification.
|
||||
*/
|
||||
public abstract class ColumnBatch implements AutoCloseable {
|
||||
/**
|
||||
* Get the cuda array interface json string for the whole ColumnBatch including
|
||||
* the must-have feature, label columns and the optional weight, base margin columns.
|
||||
*
|
||||
* This function is be called by native code during iteration and can be made as private
|
||||
* method. We keep it as public simply to silent the linter.
|
||||
*/
|
||||
public final String getArrayInterfaceJson() {
|
||||
|
||||
StringBuilder builder = new StringBuilder();
|
||||
builder.append("{");
|
||||
String featureStr = this.getFeatureArrayInterface();
|
||||
if (featureStr == null || featureStr.isEmpty()) {
|
||||
throw new RuntimeException("Feature array interface must not be empty");
|
||||
} else {
|
||||
builder.append("\"features_str\":" + featureStr);
|
||||
}
|
||||
|
||||
String labelStr = this.getLabelsArrayInterface();
|
||||
if (labelStr == null || labelStr.isEmpty()) {
|
||||
throw new RuntimeException("Label array interface must not be empty");
|
||||
} else {
|
||||
builder.append(",\"label_str\":" + labelStr);
|
||||
}
|
||||
|
||||
String weightStr = getWeightsArrayInterface();
|
||||
if (weightStr != null && ! weightStr.isEmpty()) {
|
||||
builder.append(",\"weight_str\":" + weightStr);
|
||||
}
|
||||
|
||||
String baseMarginStr = getBaseMarginsArrayInterface();
|
||||
if (baseMarginStr != null && ! baseMarginStr.isEmpty()) {
|
||||
builder.append(",\"basemargin_str\":" + baseMarginStr);
|
||||
}
|
||||
|
||||
builder.append("}");
|
||||
return builder.toString();
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the cuda array interface of the feature columns.
|
||||
* The returned value must not be null or empty
|
||||
*/
|
||||
public abstract String getFeatureArrayInterface();
|
||||
|
||||
/**
|
||||
* Get the cuda array interface of the label columns.
|
||||
* The returned value must not be null or empty if we're creating
|
||||
* {@link DeviceQuantileDMatrix#DeviceQuantileDMatrix(Iterator, float, int, int)}
|
||||
*/
|
||||
public abstract String getLabelsArrayInterface();
|
||||
|
||||
/**
|
||||
* Get the cuda array interface of the weight columns.
|
||||
* The returned value can be null or empty
|
||||
*/
|
||||
public abstract String getWeightsArrayInterface();
|
||||
|
||||
/**
|
||||
* Get the cuda array interface of the base margin columns.
|
||||
* The returned value can be null or empty
|
||||
*/
|
||||
public abstract String getBaseMarginsArrayInterface();
|
||||
|
||||
@Override
|
||||
public void close() throws Exception {}
|
||||
|
||||
}
|
||||
@@ -177,6 +177,64 @@ public class DMatrix {
|
||||
this.handle = handle;
|
||||
}
|
||||
|
||||
/**
|
||||
* Create the normal DMatrix from column array interface
|
||||
* @param columnBatch the XGBoost ColumnBatch to provide the cuda array interface
|
||||
* of feature columns
|
||||
* @param missing missing value
|
||||
* @param nthread threads number
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public DMatrix(ColumnBatch columnBatch, float missing, int nthread) throws XGBoostError {
|
||||
long[] out = new long[1];
|
||||
String json = columnBatch.getFeatureArrayInterface();
|
||||
if (json == null || json.isEmpty()) {
|
||||
throw new XGBoostError("Expecting non-empty feature columns' array interface");
|
||||
}
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixCreateFromArrayInterfaceColumns(
|
||||
json, missing, nthread, out));
|
||||
handle = out[0];
|
||||
}
|
||||
|
||||
/**
|
||||
* Set label of DMatrix from cuda array interface
|
||||
*
|
||||
* @param column the XGBoost Column to provide the cuda array interface
|
||||
* of label column
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public void setLabel(Column column) throws XGBoostError {
|
||||
setXGBDMatrixInfo("label", column.getArrayInterfaceJson());
|
||||
}
|
||||
|
||||
/**
|
||||
* Set weight of DMatrix from cuda array interface
|
||||
*
|
||||
* @param column the XGBoost Column to provide the cuda array interface
|
||||
* of weight column
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public void setWeight(Column column) throws XGBoostError {
|
||||
setXGBDMatrixInfo("weight", column.getArrayInterfaceJson());
|
||||
}
|
||||
|
||||
/**
|
||||
* Set base margin of DMatrix from cuda array interface
|
||||
*
|
||||
* @param column the XGBoost Column to provide the cuda array interface
|
||||
* of base margin column
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public void setBaseMargin(Column column) throws XGBoostError {
|
||||
setXGBDMatrixInfo("base_margin", column.getArrayInterfaceJson());
|
||||
}
|
||||
|
||||
private void setXGBDMatrixInfo(String type, String json) throws XGBoostError {
|
||||
if (json == null || json.isEmpty()) {
|
||||
throw new XGBoostError("Empty " + type + " columns' array interface");
|
||||
}
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixSetInfoFromInterface(handle, type, json));
|
||||
}
|
||||
|
||||
/**
|
||||
* set label of dmatrix
|
||||
|
||||
@@ -0,0 +1,68 @@
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
import java.util.Iterator;
|
||||
|
||||
/**
|
||||
* DeviceQuantileDMatrix will only be used to train
|
||||
*/
|
||||
public class DeviceQuantileDMatrix extends DMatrix {
|
||||
/**
|
||||
* Create DeviceQuantileDMatrix from iterator based on the cuda array interface
|
||||
* @param iter the XGBoost ColumnBatch batch to provide the corresponding cuda array interface
|
||||
* @param missing the missing value
|
||||
* @param maxBin the max bin
|
||||
* @param nthread the parallelism
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public DeviceQuantileDMatrix(
|
||||
Iterator<ColumnBatch> iter,
|
||||
float missing,
|
||||
int maxBin,
|
||||
int nthread) throws XGBoostError {
|
||||
super(0);
|
||||
long[] out = new long[1];
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGDeviceQuantileDMatrixCreateFromCallback(
|
||||
iter, missing, maxBin, nthread, out));
|
||||
handle = out[0];
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setLabel(Column column) throws XGBoostError {
|
||||
throw new XGBoostError("DeviceQuantileDMatrix does not support setLabel.");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setWeight(Column column) throws XGBoostError {
|
||||
throw new XGBoostError("DeviceQuantileDMatrix does not support setWeight.");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setBaseMargin(Column column) throws XGBoostError {
|
||||
throw new XGBoostError("DeviceQuantileDMatrix does not support setBaseMargin.");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setLabel(float[] labels) throws XGBoostError {
|
||||
throw new XGBoostError("DeviceQuantileDMatrix does not support setLabel.");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setWeight(float[] weights) throws XGBoostError {
|
||||
throw new XGBoostError("DeviceQuantileDMatrix does not support setWeight.");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setBaseMargin(float[] baseMargin) throws XGBoostError {
|
||||
throw new XGBoostError("DeviceQuantileDMatrix does not support setBaseMargin.");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setBaseMargin(float[][] baseMargin) throws XGBoostError {
|
||||
throw new XGBoostError("DeviceQuantileDMatrix does not support setBaseMargin.");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setGroup(int[] group) throws XGBoostError {
|
||||
throw new XGBoostError("DeviceQuantileDMatrix does not support setGroup.");
|
||||
}
|
||||
}
|
||||
@@ -134,4 +134,13 @@ class XGBoostJNI {
|
||||
// This JNI function does not support the callback function for data preparation yet.
|
||||
final static native int RabitAllreduce(ByteBuffer sendrecvbuf, int count,
|
||||
int enum_dtype, int enum_op);
|
||||
|
||||
public final static native int XGDMatrixSetInfoFromInterface(
|
||||
long handle, String field, String json);
|
||||
|
||||
public final static native int XGDeviceQuantileDMatrixCreateFromCallback(
|
||||
java.util.Iterator<ColumnBatch> iter, float missing, int nthread, int maxBin, long[] out);
|
||||
|
||||
public final static native int XGDMatrixCreateFromArrayInterfaceColumns(
|
||||
String featureJson, float missing, int nthread, long[] out);
|
||||
}
|
||||
|
||||
@@ -19,6 +19,7 @@
|
||||
#include <xgboost/c_api.h>
|
||||
#include <xgboost/base.h>
|
||||
#include <xgboost/logging.h>
|
||||
#include <xgboost/json.h>
|
||||
#include "./xgboost4j.h"
|
||||
#include <cstring>
|
||||
#include <vector>
|
||||
@@ -43,12 +44,14 @@ void setHandle(JNIEnv *jenv, jlongArray jhandle, void* handle) {
|
||||
jenv->SetLongArrayRegion(jhandle, 0, 1, &out);
|
||||
}
|
||||
|
||||
// global JVM
|
||||
static JavaVM* global_jvm = nullptr;
|
||||
JavaVM*& GlobalJvm() {
|
||||
static JavaVM* vm;
|
||||
return vm;
|
||||
}
|
||||
|
||||
// overrides JNI on load
|
||||
jint JNI_OnLoad(JavaVM *vm, void *reserved) {
|
||||
global_jvm = vm;
|
||||
GlobalJvm() = vm;
|
||||
return JNI_VERSION_1_6;
|
||||
}
|
||||
|
||||
@@ -58,9 +61,9 @@ XGB_EXTERN_C int XGBoost4jCallbackDataIterNext(
|
||||
DataHolderHandle set_function_handle) {
|
||||
jobject jiter = static_cast<jobject>(data_handle);
|
||||
JNIEnv* jenv;
|
||||
int jni_status = global_jvm->GetEnv((void **)&jenv, JNI_VERSION_1_6);
|
||||
int jni_status = GlobalJvm()->GetEnv((void **)&jenv, JNI_VERSION_1_6);
|
||||
if (jni_status == JNI_EDETACHED) {
|
||||
global_jvm->AttachCurrentThread(reinterpret_cast<void **>(&jenv), nullptr);
|
||||
GlobalJvm()->AttachCurrentThread(reinterpret_cast<void **>(&jenv), nullptr);
|
||||
} else {
|
||||
CHECK(jni_status == JNI_OK);
|
||||
}
|
||||
@@ -148,13 +151,13 @@ XGB_EXTERN_C int XGBoost4jCallbackDataIterNext(
|
||||
jenv->DeleteLocalRef(iterClass);
|
||||
// only detach if it is a async call.
|
||||
if (jni_status == JNI_EDETACHED) {
|
||||
global_jvm->DetachCurrentThread();
|
||||
GlobalJvm()->DetachCurrentThread();
|
||||
}
|
||||
return ret_value;
|
||||
} catch(dmlc::Error const& e) {
|
||||
// only detach if it is a async call.
|
||||
if (jni_status == JNI_EDETACHED) {
|
||||
global_jvm->DetachCurrentThread();
|
||||
GlobalJvm()->DetachCurrentThread();
|
||||
}
|
||||
LOG(FATAL) << e.what();
|
||||
return -1;
|
||||
@@ -968,3 +971,71 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitAllreduce
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
namespace xgboost {
|
||||
namespace jni {
|
||||
XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass jcls,
|
||||
jobject jiter,
|
||||
jfloat jmissing,
|
||||
jint jmax_bin, jint jnthread,
|
||||
jlongArray jout);
|
||||
} // namespace jni
|
||||
} // namespace xgboost
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGDeviceQuantileDMatrixCreateFromCallback
|
||||
* Signature: (Ljava/util/Iterator;FII[J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDeviceQuantileDMatrixCreateFromCallback
|
||||
(JNIEnv *jenv, jclass jcls, jobject jiter, jfloat jmissing, jint jmax_bin,
|
||||
jint jnthread, jlongArray jout) {
|
||||
return xgboost::jni::XGDeviceQuantileDMatrixCreateFromCallbackImpl(
|
||||
jenv, jcls, jiter, jmissing, jmax_bin, jnthread, jout);
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGDMatrixSetInfoFromInterface
|
||||
* Signature: (JLjava/lang/String;Ljava/lang/String;)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetInfoFromInterface
|
||||
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jstring jjson_columns) {
|
||||
DMatrixHandle handle = (DMatrixHandle) jhandle;
|
||||
const char* field = jenv->GetStringUTFChars(jfield, 0);
|
||||
const char* cjson_columns = jenv->GetStringUTFChars(jjson_columns, 0);
|
||||
|
||||
int ret = XGDMatrixSetInfoFromInterface(handle, field, cjson_columns);
|
||||
JVM_CHECK_CALL(ret);
|
||||
//release
|
||||
if (field) jenv->ReleaseStringUTFChars(jfield, field);
|
||||
if (cjson_columns) jenv->ReleaseStringUTFChars(jjson_columns, cjson_columns);
|
||||
return ret;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGDMatrixCreateFromArrayInterfaceColumns
|
||||
* Signature: (Ljava/lang/String;FI[J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromArrayInterfaceColumns
|
||||
(JNIEnv *jenv, jclass jcls, jstring jjson_columns, jfloat jmissing, jint jnthread, jlongArray jout) {
|
||||
DMatrixHandle result;
|
||||
const char* cjson_columns = jenv->GetStringUTFChars(jjson_columns, nullptr);
|
||||
xgboost::Json config{xgboost::Object{}};
|
||||
auto missing = static_cast<float>(jmissing);
|
||||
auto n_threads = static_cast<int32_t>(jnthread);
|
||||
config["missing"] = xgboost::Number(missing);
|
||||
config["nthread"] = xgboost::Integer(n_threads);
|
||||
std::string config_str;
|
||||
xgboost::Json::Dump(config, &config_str);
|
||||
int ret = XGDMatrixCreateFromCudaColumnar(cjson_columns, config_str.c_str(),
|
||||
&result);
|
||||
JVM_CHECK_CALL(ret);
|
||||
if (cjson_columns) {
|
||||
jenv->ReleaseStringUTFChars(jjson_columns, cjson_columns);
|
||||
}
|
||||
|
||||
setHandle(jenv, jout, result);
|
||||
return ret;
|
||||
}
|
||||
|
||||
@@ -335,6 +335,30 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitVersionNumber
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitAllreduce
|
||||
(JNIEnv *, jclass, jobject, jint, jint, jint);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_GpuXGBoostJNI
|
||||
* Method: XGDMatrixSetInfoFromInterface
|
||||
* Signature: (JLjava/lang/String;Ljava/lang/String;)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetInfoFromInterface
|
||||
(JNIEnv *, jclass, jlong, jstring, jstring);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_GpuXGBoostJNI
|
||||
* Method: XGDeviceQuantileDMatrixCreateFromCallback
|
||||
* Signature: (Ljava/util/Iterator;FII[J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDeviceQuantileDMatrixCreateFromCallback
|
||||
(JNIEnv *, jclass, jobject, jfloat, jint, jint, jlongArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_GpuXGBoostJNI
|
||||
* Method: XGDMatrixCreateFromArrayInterfaceColumns
|
||||
* Signature: (Ljava/lang/String;FI[J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromArrayInterfaceColumns
|
||||
(JNIEnv *, jclass, jstring, jfloat, jint, jlongArray);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user