[jvm-packages][xgboost4j-gpu] Support GPU dataframe and DeviceQuantileDMatrix (#7195)
The following classes are added to support dataframes in the Java binding: - `Column` is an abstract type for a single column in tabular data. - `ColumnBatch` is an abstract type for a dataframe. - `CudfColumn` is an implementation of `Column` that consumes a cuDF column. - `CudfColumnBatch` is an implementation of `ColumnBatch` that consumes a cuDF dataframe. - `DeviceQuantileDMatrix` is the interface for quantized data. The Java implementation mimics the Python interface and uses the `__cuda_array_interface__` protocol for memory indexing. One difference is that, in the JVM package, the data batch is staged on the host because Java iterators cannot be reset. Co-authored-by: jiamingy <jm.yuan@outlook.com>
This commit is contained in:
@@ -1 +0,0 @@
|
||||
../../../xgboost4j/src/main/java/
|
||||
@@ -0,0 +1,110 @@
|
||||
/*
|
||||
Copyright (c) 2021 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.gpu.java;
|
||||
|
||||
import ai.rapids.cudf.BaseDeviceMemoryBuffer;
|
||||
import ai.rapids.cudf.BufferType;
|
||||
import ai.rapids.cudf.ColumnVector;
|
||||
import ai.rapids.cudf.DType;
|
||||
|
||||
import ml.dmlc.xgboost4j.java.Column;
|
||||
|
||||
/**
|
||||
* This class is composed of base data in Apache Arrow format taken from a Cudf ColumnVector.
|
||||
* It will be used to generate the cuda array interface.
|
||||
*/
|
||||
class CudfColumn extends Column {
|
||||
|
||||
private final long dataPtr; // gpu data buffer address
|
||||
private final long shape; // row count
|
||||
private final long validPtr; // gpu valid buffer address
|
||||
private final int typeSize; // type size in bytes
|
||||
private final String typeStr; // follow array interface spec
|
||||
private final long nullCount; // null count
|
||||
|
||||
private String arrayInterface = null; // the cuda array interface
|
||||
|
||||
public static CudfColumn from(ColumnVector cv) {
|
||||
BaseDeviceMemoryBuffer dataBuffer = cv.getDeviceBufferFor(BufferType.DATA);
|
||||
BaseDeviceMemoryBuffer validBuffer = cv.getDeviceBufferFor(BufferType.VALIDITY);
|
||||
long validPtr = 0;
|
||||
if (validBuffer != null) {
|
||||
validPtr = validBuffer.getAddress();
|
||||
}
|
||||
DType dType = cv.getType();
|
||||
String typeStr = "";
|
||||
if (dType == DType.FLOAT32 || dType == DType.FLOAT64 ||
|
||||
dType == DType.TIMESTAMP_DAYS || dType == DType.TIMESTAMP_MICROSECONDS ||
|
||||
dType == DType.TIMESTAMP_MILLISECONDS || dType == DType.TIMESTAMP_NANOSECONDS ||
|
||||
dType == DType.TIMESTAMP_SECONDS) {
|
||||
typeStr = "<f" + dType.getSizeInBytes();
|
||||
} else if (dType == DType.BOOL8 || dType == DType.INT8 || dType == DType.INT16 ||
|
||||
dType == DType.INT32 || dType == DType.INT64) {
|
||||
typeStr = "<i" + dType.getSizeInBytes();
|
||||
} else {
|
||||
// Unsupported type.
|
||||
throw new IllegalArgumentException("Unsupported data type: " + dType);
|
||||
}
|
||||
|
||||
return new CudfColumn(dataBuffer.getAddress(), cv.getRowCount(), validPtr,
|
||||
dType.getSizeInBytes(), typeStr, cv.getNullCount());
|
||||
}
|
||||
|
||||
private CudfColumn(long dataPtr, long shape, long validPtr, int typeSize, String typeStr,
|
||||
long nullCount) {
|
||||
this.dataPtr = dataPtr;
|
||||
this.shape = shape;
|
||||
this.validPtr = validPtr;
|
||||
this.typeSize = typeSize;
|
||||
this.typeStr = typeStr;
|
||||
this.nullCount = nullCount;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getArrayInterfaceJson() {
|
||||
// There is no race-condition
|
||||
if (arrayInterface == null) {
|
||||
arrayInterface = CudfUtils.buildArrayInterface(this);
|
||||
}
|
||||
return arrayInterface;
|
||||
}
|
||||
|
||||
public long getDataPtr() {
|
||||
return dataPtr;
|
||||
}
|
||||
|
||||
public long getShape() {
|
||||
return shape;
|
||||
}
|
||||
|
||||
public long getValidPtr() {
|
||||
return validPtr;
|
||||
}
|
||||
|
||||
public int getTypeSize() {
|
||||
return typeSize;
|
||||
}
|
||||
|
||||
public String getTypeStr() {
|
||||
return typeStr;
|
||||
}
|
||||
|
||||
public long getNullCount() {
|
||||
return nullCount;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,88 @@
|
||||
/*
|
||||
Copyright (c) 2021 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.gpu.java;
|
||||
|
||||
import java.util.stream.IntStream;
|
||||
|
||||
import ai.rapids.cudf.Table;
|
||||
|
||||
import ml.dmlc.xgboost4j.java.ColumnBatch;
|
||||
|
||||
/**
|
||||
* Class to wrap CUDF Table to generate the cuda array interface.
|
||||
*/
|
||||
public class CudfColumnBatch extends ColumnBatch {
|
||||
private final Table feature;
|
||||
private final Table label;
|
||||
private final Table weight;
|
||||
private final Table baseMargin;
|
||||
|
||||
public CudfColumnBatch(Table feature, Table labels, Table weights, Table baseMargins) {
|
||||
this.feature = feature;
|
||||
this.label = labels;
|
||||
this.weight = weights;
|
||||
this.baseMargin = baseMargins;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getFeatureArrayInterface() {
|
||||
return getArrayInterface(this.feature);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getLabelsArrayInterface() {
|
||||
return getArrayInterface(this.label);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWeightsArrayInterface() {
|
||||
return getArrayInterface(this.weight);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getBaseMarginsArrayInterface() {
|
||||
return getArrayInterface(this.baseMargin);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
if (feature != null) feature.close();
|
||||
if (label != null) label.close();
|
||||
if (weight != null) weight.close();
|
||||
if (baseMargin != null) baseMargin.close();
|
||||
}
|
||||
|
||||
private String getArrayInterface(Table table) {
|
||||
if (table == null || table.getNumberOfColumns() == 0) {
|
||||
return "";
|
||||
}
|
||||
return CudfUtils.buildArrayInterface(getAsCudfColumn(table));
|
||||
}
|
||||
|
||||
private CudfColumn[] getAsCudfColumn(Table table) {
|
||||
if (table == null || table.getNumberOfColumns() == 0) {
|
||||
// This will never happen.
|
||||
return new CudfColumn[]{};
|
||||
}
|
||||
|
||||
return IntStream.range(0, table.getNumberOfColumns())
|
||||
.mapToObj((i) -> table.getColumn(i))
|
||||
.map(CudfColumn::from)
|
||||
.toArray(CudfColumn[]::new);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,100 @@
|
||||
/*
|
||||
Copyright (c) 2021 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.gpu.java;
|
||||
|
||||
import java.io.ByteArrayOutputStream;
|
||||
import java.io.IOException;
|
||||
|
||||
import com.fasterxml.jackson.core.JsonFactory;
|
||||
import com.fasterxml.jackson.core.JsonGenerator;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import com.fasterxml.jackson.databind.node.ArrayNode;
|
||||
import com.fasterxml.jackson.databind.node.JsonNodeFactory;
|
||||
import com.fasterxml.jackson.databind.node.ObjectNode;
|
||||
|
||||
/**
|
||||
* Cudf utilities to build cuda array interface against {@link CudfColumn}
|
||||
*/
|
||||
class CudfUtils {
|
||||
|
||||
/**
|
||||
* Build the cuda array interface based on CudfColumn(s)
|
||||
* @param cudfColumns the CudfColumn(s) to be built
|
||||
* @return the json format of cuda array interface
|
||||
*/
|
||||
public static String buildArrayInterface(CudfColumn... cudfColumns) {
|
||||
return new Builder().add(cudfColumns).build();
|
||||
}
|
||||
|
||||
// Helper class to build array interface string
|
||||
private static class Builder {
|
||||
private JsonNodeFactory nodeFactory = new JsonNodeFactory(false);
|
||||
private ArrayNode rootArrayNode = nodeFactory.arrayNode();
|
||||
|
||||
private Builder add(CudfColumn... columns) {
|
||||
if (columns == null || columns.length <= 0) {
|
||||
throw new IllegalArgumentException("At least one ColumnData is required.");
|
||||
}
|
||||
for (CudfColumn cd : columns) {
|
||||
rootArrayNode.add(buildColumnObject(cd));
|
||||
}
|
||||
return this;
|
||||
}
|
||||
|
||||
private String build() {
|
||||
try {
|
||||
ByteArrayOutputStream bos = new ByteArrayOutputStream();
|
||||
JsonGenerator jsonGen = new JsonFactory().createGenerator(bos);
|
||||
new ObjectMapper().writeTree(jsonGen, rootArrayNode);
|
||||
return bos.toString();
|
||||
} catch (IOException ie) {
|
||||
ie.printStackTrace();
|
||||
throw new RuntimeException("Failed to build array interface. Error: " + ie);
|
||||
}
|
||||
}
|
||||
|
||||
private ObjectNode buildColumnObject(CudfColumn column) {
|
||||
if (column.getDataPtr() == 0) {
|
||||
throw new IllegalArgumentException("Empty column data is NOT accepted!");
|
||||
}
|
||||
if (column.getTypeStr() == null || column.getTypeStr().isEmpty()) {
|
||||
throw new IllegalArgumentException("Empty type string is NOT accepted!");
|
||||
}
|
||||
ObjectNode colDataObj = buildMetaObject(column.getDataPtr(), column.getShape(),
|
||||
column.getTypeStr());
|
||||
|
||||
if (column.getValidPtr() != 0 && column.getNullCount() != 0) {
|
||||
ObjectNode validObj = buildMetaObject(column.getValidPtr(), column.getShape(), "<t1");
|
||||
colDataObj.set("mask", validObj);
|
||||
}
|
||||
return colDataObj;
|
||||
}
|
||||
|
||||
private ObjectNode buildMetaObject(long ptr, long shape, final String typeStr) {
|
||||
ObjectNode objNode = nodeFactory.objectNode();
|
||||
ArrayNode shapeNode = objNode.putArray("shape");
|
||||
shapeNode.add(shape);
|
||||
ArrayNode dataNode = objNode.putArray("data");
|
||||
dataNode.add(ptr)
|
||||
.add(false);
|
||||
objNode.put("typestr", typeStr)
|
||||
.put("version", 1);
|
||||
return objNode;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
1
jvm-packages/xgboost4j-gpu/src/main/java/ml/dmlc/xgboost4j/java
Symbolic link
1
jvm-packages/xgboost4j-gpu/src/main/java/ml/dmlc/xgboost4j/java
Symbolic link
@@ -0,0 +1 @@
|
||||
../../../../../../../xgboost4j/src/main/java/ml/dmlc/xgboost4j/java
|
||||
Reference in New Issue
Block a user