[jvm-packages][xgboost4j-gpu] Support GPU dataframe and DeviceQuantileDMatrix (#7195)
The following classes are added to support dataframes in the Java binding: - `Column` is an abstract type for a single column in tabular data. - `ColumnBatch` is an abstract type for a dataframe. - `CuDFColumn` is an implementation of `Column` that consumes a cuDF column. - `CudfColumnBatch` is an implementation of `ColumnBatch` that consumes a cuDF dataframe. - `DeviceQuantileDMatrix` is the interface for quantized data. The Java implementation mimics the Python interface and uses the `__cuda_array_interface__` protocol for memory indexing. One difference is that in the JVM package the data batch is staged on the host, as Java iterators cannot be reset. Co-authored-by: jiamingy <jm.yuan@outlook.com>
This commit is contained in:
@@ -0,0 +1,40 @@
|
||||
/*
|
||||
Copyright (c) 2021 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
/**
 * The abstracted XGBoost Column, used to obtain the cuda array interface which in turn
 * is used to set the information for a DMatrix.
 */
public abstract class Column implements AutoCloseable {

  /**
   * Get the cuda array interface json string for the Column, which can represent
   * a weight, label or base margin column.
   *
   * <p>This API will be called by
   * <ul>
   *   <li>{@link DMatrix#setLabel(Column)}</li>
   *   <li>{@link DMatrix#setWeight(Column)}</li>
   *   <li>{@link DMatrix#setBaseMargin(Column)}</li>
   * </ul>
   *
   * @return the cuda array interface of this column as a json string
   */
  public abstract String getArrayInterfaceJson();

  /**
   * Does nothing in this base class; subclasses may override it.
   *
   * @throws Exception declared so that overriding implementations may throw
   */
  @Override
  public void close() throws Exception {}

}
|
||||
@@ -0,0 +1,93 @@
|
||||
/*
|
||||
Copyright (c) 2021 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
import java.util.Iterator;
|
||||
|
||||
/**
 * The abstracted XGBoost ColumnBatch, used to obtain the array interface from a columnar
 * data format. For example, the cuDF dataframe which employs the apache arrow specification.
 */
public abstract class ColumnBatch implements AutoCloseable {

  /**
   * Get the cuda array interface json string for the whole ColumnBatch, including
   * the must-have feature and label columns and the optional weight and base margin columns.
   *
   * <p>This function is called by native code during iteration and could be made a private
   * method. We keep it public simply to silence the linter.
   *
   * @return a json object with the keys {@code features_str} and {@code label_str}, plus
   *         {@code weight_str} and {@code basemargin_str} when those columns are present
   * @throws RuntimeException if the feature or the label array interface is null or empty
   */
  public final String getArrayInterfaceJson() {
    // Mandatory parts are validated in the same order they are fetched: features, then labels.
    String featureStr = this.getFeatureArrayInterface();
    if (isNullOrEmpty(featureStr)) {
      throw new RuntimeException("Feature array interface must not be empty");
    }

    String labelStr = this.getLabelsArrayInterface();
    if (isNullOrEmpty(labelStr)) {
      throw new RuntimeException("Label array interface must not be empty");
    }

    StringBuilder builder = new StringBuilder();
    builder.append("{\"features_str\":").append(featureStr);
    builder.append(",\"label_str\":").append(labelStr);

    // Optional parts are simply left out of the json when absent.
    String weightStr = getWeightsArrayInterface();
    if (!isNullOrEmpty(weightStr)) {
      builder.append(",\"weight_str\":").append(weightStr);
    }

    String baseMarginStr = getBaseMarginsArrayInterface();
    if (!isNullOrEmpty(baseMarginStr)) {
      builder.append(",\"basemargin_str\":").append(baseMarginStr);
    }

    return builder.append("}").toString();
  }

  /** Tell whether the given array interface string is missing (null or empty). */
  private static boolean isNullOrEmpty(String value) {
    return value == null || value.isEmpty();
  }

  /**
   * Get the cuda array interface of the feature columns.
   * The returned value must not be null or empty.
   */
  public abstract String getFeatureArrayInterface();

  /**
   * Get the cuda array interface of the label columns.
   * The returned value must not be null or empty if we're creating
   * {@link DeviceQuantileDMatrix#DeviceQuantileDMatrix(Iterator, float, int, int)}
   */
  public abstract String getLabelsArrayInterface();

  /**
   * Get the cuda array interface of the weight columns.
   * The returned value can be null or empty.
   */
  public abstract String getWeightsArrayInterface();

  /**
   * Get the cuda array interface of the base margin columns.
   * The returned value can be null or empty.
   */
  public abstract String getBaseMarginsArrayInterface();

  /**
   * Does nothing in this base class; subclasses may override it.
   *
   * @throws Exception declared so that overriding implementations may throw
   */
  @Override
  public void close() throws Exception {}

}
|
||||
@@ -177,6 +177,64 @@ public class DMatrix {
|
||||
this.handle = handle;
|
||||
}
|
||||
|
||||
/**
 * Create the normal DMatrix from the column array interface.
 *
 * @param columnBatch the XGBoost ColumnBatch to provide the cuda array interface
 *                    of feature columns
 * @param missing missing value
 * @param nthread threads number
 * @throws XGBoostError if the feature array interface is null or empty, or on native error
 */
public DMatrix(ColumnBatch columnBatch, float missing, int nthread) throws XGBoostError {
  // Only the feature columns are read here; the other getters of the batch are not called.
  String json = columnBatch.getFeatureArrayInterface();
  if (json == null || json.isEmpty()) {
    throw new XGBoostError("Expecting non-empty feature columns' array interface");
  }
  // The native call writes the new handle into out[0].
  long[] out = new long[1];
  XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixCreateFromArrayInterfaceColumns(
      json, missing, nthread, out));
  handle = out[0];
}
|
||||
|
||||
/**
|
||||
* Set label of DMatrix from cuda array interface
|
||||
*
|
||||
* @param column the XGBoost Column to provide the cuda array interface
|
||||
* of label column
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public void setLabel(Column column) throws XGBoostError {
|
||||
setXGBDMatrixInfo("label", column.getArrayInterfaceJson());
|
||||
}
|
||||
|
||||
/**
|
||||
* Set weight of DMatrix from cuda array interface
|
||||
*
|
||||
* @param column the XGBoost Column to provide the cuda array interface
|
||||
* of weight column
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public void setWeight(Column column) throws XGBoostError {
|
||||
setXGBDMatrixInfo("weight", column.getArrayInterfaceJson());
|
||||
}
|
||||
|
||||
/**
|
||||
* Set base margin of DMatrix from cuda array interface
|
||||
*
|
||||
* @param column the XGBoost Column to provide the cuda array interface
|
||||
* of base margin column
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public void setBaseMargin(Column column) throws XGBoostError {
|
||||
setXGBDMatrixInfo("base_margin", column.getArrayInterfaceJson());
|
||||
}
|
||||
|
||||
private void setXGBDMatrixInfo(String type, String json) throws XGBoostError {
|
||||
if (json == null || json.isEmpty()) {
|
||||
throw new XGBoostError("Empty " + type + " columns' array interface");
|
||||
}
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixSetInfoFromInterface(handle, type, json));
|
||||
}
|
||||
|
||||
/**
|
||||
* set label of dmatrix
|
||||
|
||||
@@ -0,0 +1,68 @@
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
import java.util.Iterator;
|
||||
|
||||
/**
|
||||
* DeviceQuantileDMatrix will only be used to train
|
||||
*/
|
||||
public class DeviceQuantileDMatrix extends DMatrix {
|
||||
/**
|
||||
* Create DeviceQuantileDMatrix from iterator based on the cuda array interface
|
||||
* @param iter the XGBoost ColumnBatch batch to provide the corresponding cuda array interface
|
||||
* @param missing the missing value
|
||||
* @param maxBin the max bin
|
||||
* @param nthread the parallelism
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public DeviceQuantileDMatrix(
|
||||
Iterator<ColumnBatch> iter,
|
||||
float missing,
|
||||
int maxBin,
|
||||
int nthread) throws XGBoostError {
|
||||
super(0);
|
||||
long[] out = new long[1];
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGDeviceQuantileDMatrixCreateFromCallback(
|
||||
iter, missing, maxBin, nthread, out));
|
||||
handle = out[0];
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setLabel(Column column) throws XGBoostError {
|
||||
throw new XGBoostError("DeviceQuantileDMatrix does not support setLabel.");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setWeight(Column column) throws XGBoostError {
|
||||
throw new XGBoostError("DeviceQuantileDMatrix does not support setWeight.");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setBaseMargin(Column column) throws XGBoostError {
|
||||
throw new XGBoostError("DeviceQuantileDMatrix does not support setBaseMargin.");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setLabel(float[] labels) throws XGBoostError {
|
||||
throw new XGBoostError("DeviceQuantileDMatrix does not support setLabel.");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setWeight(float[] weights) throws XGBoostError {
|
||||
throw new XGBoostError("DeviceQuantileDMatrix does not support setWeight.");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setBaseMargin(float[] baseMargin) throws XGBoostError {
|
||||
throw new XGBoostError("DeviceQuantileDMatrix does not support setBaseMargin.");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setBaseMargin(float[][] baseMargin) throws XGBoostError {
|
||||
throw new XGBoostError("DeviceQuantileDMatrix does not support setBaseMargin.");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setGroup(int[] group) throws XGBoostError {
|
||||
throw new XGBoostError("DeviceQuantileDMatrix does not support setGroup.");
|
||||
}
|
||||
}
|
||||
@@ -134,4 +134,13 @@ class XGBoostJNI {
|
||||
// This JNI function does not support the callback function for data preparation yet.
|
||||
final static native int RabitAllreduce(ByteBuffer sendrecvbuf, int count,
|
||||
int enum_dtype, int enum_op);
|
||||
|
||||
public final static native int XGDMatrixSetInfoFromInterface(
|
||||
long handle, String field, String json);
|
||||
|
||||
public final static native int XGDeviceQuantileDMatrixCreateFromCallback(
|
||||
java.util.Iterator<ColumnBatch> iter, float missing, int nthread, int maxBin, long[] out);
|
||||
|
||||
public final static native int XGDMatrixCreateFromArrayInterfaceColumns(
|
||||
String featureJson, float missing, int nthread, long[] out);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user