[jvm-packages][xgboost4j-gpu] Support GPU dataframe and DeviceQuantileDMatrix (#7195)
The following classes are added to support dataframes in the Java binding: - `Column` is an abstract type for a single column in tabular data. - `ColumnBatch` is an abstract type for a dataframe. - `CudfColumn` is an implementation of `Column` that consumes a cuDF column. - `CudfColumnBatch` is an implementation of `ColumnBatch` that consumes a cuDF dataframe. - `DeviceQuantileDMatrix` is the interface for quantized data. The Java implementation mimics the Python interface and uses the `__cuda_array_interface__` protocol for memory indexing. One difference is that in the JVM package the data batches are staged on the host, because Java iterators cannot be reset. Co-authored-by: jiamingy <jm.yuan@outlook.com>
This commit is contained in:
parent
d27a427dc5
commit
0ee11dac77
2
Jenkinsfile
vendored
2
Jenkinsfile
vendored
@ -64,7 +64,7 @@ pipeline {
|
||||
// The build-gpu-* builds below use Ubuntu image
|
||||
'build-gpu-cuda11.0': { BuildCUDA(cuda_version: '11.0', build_rmm: true) },
|
||||
'build-gpu-rpkg': { BuildRPackageWithCUDA(cuda_version: '10.1') },
|
||||
'build-jvm-packages-gpu-cuda10.1': { BuildJVMPackagesWithCUDA(spark_version: '3.0.0', cuda_version: '10.1') },
|
||||
'build-jvm-packages-gpu-cuda10.1': { BuildJVMPackagesWithCUDA(spark_version: '3.0.0', cuda_version: '11.0') },
|
||||
'build-jvm-packages': { BuildJVMPackages(spark_version: '3.0.0') },
|
||||
'build-jvm-doc': { BuildJVMDoc() }
|
||||
])
|
||||
|
||||
@ -1,10 +1,20 @@
|
||||
find_package(JNI REQUIRED)
|
||||
|
||||
add_library(xgboost4j SHARED
|
||||
${PROJECT_SOURCE_DIR}/jvm-packages/xgboost4j/src/native/xgboost4j.cpp)
|
||||
list(APPEND JVM_SOURCES
|
||||
${PROJECT_SOURCE_DIR}/jvm-packages/xgboost4j/src/native/xgboost4j.cpp
|
||||
${PROJECT_SOURCE_DIR}/jvm-packages/xgboost4j-gpu/src/native/xgboost4j-gpu.cpp)
|
||||
|
||||
if (USE_CUDA)
|
||||
list(APPEND JVM_SOURCES
|
||||
${PROJECT_SOURCE_DIR}/jvm-packages/xgboost4j-gpu/src/native/xgboost4j-gpu.cu)
|
||||
endif (USE_CUDA)
|
||||
|
||||
add_library(xgboost4j SHARED ${JVM_SOURCES} ${XGBOOST_OBJ_SOURCES})
|
||||
|
||||
if (ENABLE_ALL_WARNINGS)
|
||||
target_compile_options(xgboost4j PUBLIC -Wall -Wextra)
|
||||
endif (ENABLE_ALL_WARNINGS)
|
||||
|
||||
target_link_libraries(xgboost4j PRIVATE objxgboost)
|
||||
target_include_directories(xgboost4j
|
||||
PRIVATE
|
||||
@ -15,8 +25,4 @@ target_include_directories(xgboost4j
|
||||
${PROJECT_SOURCE_DIR}/rabit/include)
|
||||
|
||||
set_output_directory(xgboost4j ${PROJECT_SOURCE_DIR}/lib)
|
||||
set_target_properties(
|
||||
xgboost4j PROPERTIES
|
||||
CXX_STANDARD 14
|
||||
CXX_STANDARD_REQUIRED ON)
|
||||
target_link_libraries(xgboost4j PRIVATE ${JAVA_JVM_LIBRARY})
|
||||
|
||||
@ -12,7 +12,24 @@
|
||||
<version>1.5.0-SNAPSHOT</version>
|
||||
<packaging>jar</packaging>
|
||||
|
||||
<properties>
|
||||
<cudf.version>21.08.2</cudf.version>
|
||||
<cudf.classifier>cuda11</cudf.classifier>
|
||||
</properties>
|
||||
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>ai.rapids</groupId>
|
||||
<artifactId>cudf</artifactId>
|
||||
<version>${cudf.version}</version>
|
||||
<classifier>${cudf.classifier}</classifier>
|
||||
<scope>provided</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.fasterxml.jackson.core</groupId>
|
||||
<artifactId>jackson-databind</artifactId>
|
||||
<version>2.10.5.1</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.hadoop</groupId>
|
||||
<artifactId>hadoop-hdfs</artifactId>
|
||||
|
||||
@ -1 +0,0 @@
|
||||
../../../xgboost4j/src/main/java/
|
||||
@ -0,0 +1,110 @@
|
||||
/*
|
||||
Copyright (c) 2021 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.gpu.java;
|
||||
|
||||
import ai.rapids.cudf.BaseDeviceMemoryBuffer;
|
||||
import ai.rapids.cudf.BufferType;
|
||||
import ai.rapids.cudf.ColumnVector;
|
||||
import ai.rapids.cudf.DType;
|
||||
|
||||
import ml.dmlc.xgboost4j.java.Column;
|
||||
|
||||
/**
|
||||
* This class is composing of base data with Apache Arrow format from Cudf ColumnVector.
|
||||
* It will be used to generate the cuda array interface.
|
||||
*/
|
||||
class CudfColumn extends Column {
|
||||
|
||||
private final long dataPtr; // gpu data buffer address
|
||||
private final long shape; // row count
|
||||
private final long validPtr; // gpu valid buffer address
|
||||
private final int typeSize; // type size in bytes
|
||||
private final String typeStr; // follow array interface spec
|
||||
private final long nullCount; // null count
|
||||
|
||||
private String arrayInterface = null; // the cuda array interface
|
||||
|
||||
public static CudfColumn from(ColumnVector cv) {
|
||||
BaseDeviceMemoryBuffer dataBuffer = cv.getDeviceBufferFor(BufferType.DATA);
|
||||
BaseDeviceMemoryBuffer validBuffer = cv.getDeviceBufferFor(BufferType.VALIDITY);
|
||||
long validPtr = 0;
|
||||
if (validBuffer != null) {
|
||||
validPtr = validBuffer.getAddress();
|
||||
}
|
||||
DType dType = cv.getType();
|
||||
String typeStr = "";
|
||||
if (dType == DType.FLOAT32 || dType == DType.FLOAT64 ||
|
||||
dType == DType.TIMESTAMP_DAYS || dType == DType.TIMESTAMP_MICROSECONDS ||
|
||||
dType == DType.TIMESTAMP_MILLISECONDS || dType == DType.TIMESTAMP_NANOSECONDS ||
|
||||
dType == DType.TIMESTAMP_SECONDS) {
|
||||
typeStr = "<f" + dType.getSizeInBytes();
|
||||
} else if (dType == DType.BOOL8 || dType == DType.INT8 || dType == DType.INT16 ||
|
||||
dType == DType.INT32 || dType == DType.INT64) {
|
||||
typeStr = "<i" + dType.getSizeInBytes();
|
||||
} else {
|
||||
// Unsupported type.
|
||||
throw new IllegalArgumentException("Unsupported data type: " + dType);
|
||||
}
|
||||
|
||||
return new CudfColumn(dataBuffer.getAddress(), cv.getRowCount(), validPtr,
|
||||
dType.getSizeInBytes(), typeStr, cv.getNullCount());
|
||||
}
|
||||
|
||||
private CudfColumn(long dataPtr, long shape, long validPtr, int typeSize, String typeStr,
|
||||
long nullCount) {
|
||||
this.dataPtr = dataPtr;
|
||||
this.shape = shape;
|
||||
this.validPtr = validPtr;
|
||||
this.typeSize = typeSize;
|
||||
this.typeStr = typeStr;
|
||||
this.nullCount = nullCount;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getArrayInterfaceJson() {
|
||||
// There is no race-condition
|
||||
if (arrayInterface == null) {
|
||||
arrayInterface = CudfUtils.buildArrayInterface(this);
|
||||
}
|
||||
return arrayInterface;
|
||||
}
|
||||
|
||||
public long getDataPtr() {
|
||||
return dataPtr;
|
||||
}
|
||||
|
||||
public long getShape() {
|
||||
return shape;
|
||||
}
|
||||
|
||||
public long getValidPtr() {
|
||||
return validPtr;
|
||||
}
|
||||
|
||||
public int getTypeSize() {
|
||||
return typeSize;
|
||||
}
|
||||
|
||||
public String getTypeStr() {
|
||||
return typeStr;
|
||||
}
|
||||
|
||||
public long getNullCount() {
|
||||
return nullCount;
|
||||
}
|
||||
|
||||
}
|
||||
@ -0,0 +1,88 @@
|
||||
/*
|
||||
Copyright (c) 2021 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.gpu.java;
|
||||
|
||||
import java.util.stream.IntStream;
|
||||
|
||||
import ai.rapids.cudf.Table;
|
||||
|
||||
import ml.dmlc.xgboost4j.java.ColumnBatch;
|
||||
|
||||
/**
|
||||
* Class to wrap CUDF Table to generate the cuda array interface.
|
||||
*/
|
||||
public class CudfColumnBatch extends ColumnBatch {
|
||||
private final Table feature;
|
||||
private final Table label;
|
||||
private final Table weight;
|
||||
private final Table baseMargin;
|
||||
|
||||
public CudfColumnBatch(Table feature, Table labels, Table weights, Table baseMargins) {
|
||||
this.feature = feature;
|
||||
this.label = labels;
|
||||
this.weight = weights;
|
||||
this.baseMargin = baseMargins;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getFeatureArrayInterface() {
|
||||
return getArrayInterface(this.feature);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getLabelsArrayInterface() {
|
||||
return getArrayInterface(this.label);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWeightsArrayInterface() {
|
||||
return getArrayInterface(this.weight);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getBaseMarginsArrayInterface() {
|
||||
return getArrayInterface(this.baseMargin);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
if (feature != null) feature.close();
|
||||
if (label != null) label.close();
|
||||
if (weight != null) weight.close();
|
||||
if (baseMargin != null) baseMargin.close();
|
||||
}
|
||||
|
||||
private String getArrayInterface(Table table) {
|
||||
if (table == null || table.getNumberOfColumns() == 0) {
|
||||
return "";
|
||||
}
|
||||
return CudfUtils.buildArrayInterface(getAsCudfColumn(table));
|
||||
}
|
||||
|
||||
private CudfColumn[] getAsCudfColumn(Table table) {
|
||||
if (table == null || table.getNumberOfColumns() == 0) {
|
||||
// This will never happen.
|
||||
return new CudfColumn[]{};
|
||||
}
|
||||
|
||||
return IntStream.range(0, table.getNumberOfColumns())
|
||||
.mapToObj((i) -> table.getColumn(i))
|
||||
.map(CudfColumn::from)
|
||||
.toArray(CudfColumn[]::new);
|
||||
}
|
||||
|
||||
}
|
||||
@ -0,0 +1,100 @@
|
||||
/*
|
||||
Copyright (c) 2021 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.gpu.java;
|
||||
|
||||
import java.io.ByteArrayOutputStream;
|
||||
import java.io.IOException;
|
||||
|
||||
import com.fasterxml.jackson.core.JsonFactory;
|
||||
import com.fasterxml.jackson.core.JsonGenerator;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import com.fasterxml.jackson.databind.node.ArrayNode;
|
||||
import com.fasterxml.jackson.databind.node.JsonNodeFactory;
|
||||
import com.fasterxml.jackson.databind.node.ObjectNode;
|
||||
|
||||
/**
|
||||
* Cudf utilities to build cuda array interface against {@link CudfColumn}
|
||||
*/
|
||||
class CudfUtils {
|
||||
|
||||
/**
|
||||
* Build the cuda array interface based on CudfColumn(s)
|
||||
* @param cudfColumns the CudfColumn(s) to be built
|
||||
* @return the json format of cuda array interface
|
||||
*/
|
||||
public static String buildArrayInterface(CudfColumn... cudfColumns) {
|
||||
return new Builder().add(cudfColumns).build();
|
||||
}
|
||||
|
||||
// Helper class to build array interface string
|
||||
private static class Builder {
|
||||
private JsonNodeFactory nodeFactory = new JsonNodeFactory(false);
|
||||
private ArrayNode rootArrayNode = nodeFactory.arrayNode();
|
||||
|
||||
private Builder add(CudfColumn... columns) {
|
||||
if (columns == null || columns.length <= 0) {
|
||||
throw new IllegalArgumentException("At least one ColumnData is required.");
|
||||
}
|
||||
for (CudfColumn cd : columns) {
|
||||
rootArrayNode.add(buildColumnObject(cd));
|
||||
}
|
||||
return this;
|
||||
}
|
||||
|
||||
private String build() {
|
||||
try {
|
||||
ByteArrayOutputStream bos = new ByteArrayOutputStream();
|
||||
JsonGenerator jsonGen = new JsonFactory().createGenerator(bos);
|
||||
new ObjectMapper().writeTree(jsonGen, rootArrayNode);
|
||||
return bos.toString();
|
||||
} catch (IOException ie) {
|
||||
ie.printStackTrace();
|
||||
throw new RuntimeException("Failed to build array interface. Error: " + ie);
|
||||
}
|
||||
}
|
||||
|
||||
private ObjectNode buildColumnObject(CudfColumn column) {
|
||||
if (column.getDataPtr() == 0) {
|
||||
throw new IllegalArgumentException("Empty column data is NOT accepted!");
|
||||
}
|
||||
if (column.getTypeStr() == null || column.getTypeStr().isEmpty()) {
|
||||
throw new IllegalArgumentException("Empty type string is NOT accepted!");
|
||||
}
|
||||
ObjectNode colDataObj = buildMetaObject(column.getDataPtr(), column.getShape(),
|
||||
column.getTypeStr());
|
||||
|
||||
if (column.getValidPtr() != 0 && column.getNullCount() != 0) {
|
||||
ObjectNode validObj = buildMetaObject(column.getValidPtr(), column.getShape(), "<t1");
|
||||
colDataObj.set("mask", validObj);
|
||||
}
|
||||
return colDataObj;
|
||||
}
|
||||
|
||||
private ObjectNode buildMetaObject(long ptr, long shape, final String typeStr) {
|
||||
ObjectNode objNode = nodeFactory.objectNode();
|
||||
ArrayNode shapeNode = objNode.putArray("shape");
|
||||
shapeNode.add(shape);
|
||||
ArrayNode dataNode = objNode.putArray("data");
|
||||
dataNode.add(ptr)
|
||||
.add(false);
|
||||
objNode.put("typestr", typeStr)
|
||||
.put("version", 1);
|
||||
return objNode;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
1
jvm-packages/xgboost4j-gpu/src/main/java/ml/dmlc/xgboost4j/java
Symbolic link
1
jvm-packages/xgboost4j-gpu/src/main/java/ml/dmlc/xgboost4j/java
Symbolic link
@ -0,0 +1 @@
|
||||
../../../../../../../xgboost4j/src/main/java/ml/dmlc/xgboost4j/java
|
||||
@ -1 +0,0 @@
|
||||
../../xgboost4j/src/native
|
||||
15
jvm-packages/xgboost4j-gpu/src/native/jvm_utils.h
Normal file
15
jvm-packages/xgboost4j-gpu/src/native/jvm_utils.h
Normal file
@ -0,0 +1,15 @@
|
||||
#ifndef JVM_UTILS_H_
|
||||
#define JVM_UTILS_H_
|
||||
|
||||
#define JVM_CHECK_CALL(__expr) \
|
||||
{ \
|
||||
int __errcode = (__expr); \
|
||||
if (__errcode != 0) { \
|
||||
return __errcode; \
|
||||
} \
|
||||
}
|
||||
|
||||
JavaVM*& GlobalJvm();
|
||||
void setHandle(JNIEnv *jenv, jlongArray jhandle, void* handle);
|
||||
|
||||
#endif // JVM_UTILS_H_
|
||||
25
jvm-packages/xgboost4j-gpu/src/native/xgboost4j-gpu.cpp
Normal file
25
jvm-packages/xgboost4j-gpu/src/native/xgboost4j-gpu.cpp
Normal file
@ -0,0 +1,25 @@
|
||||
//
|
||||
// Created by bobwang on 2021/9/8.
|
||||
//
|
||||
|
||||
#ifndef XGBOOST_USE_CUDA
|
||||
|
||||
#include <jni.h>
|
||||
|
||||
#include "../../../../src/common/common.h"
|
||||
#include "../../../../src/c_api/c_api_error.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace jni {
|
||||
/*!
 * \brief Stand-in used when XGBOOST_USE_CUDA is not defined.
 *
 * None of the arguments are read; the body only calls
 * common::AssertGPUSupport() between API_BEGIN() and API_END().
 */
XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallbackImpl(
    JNIEnv *jenv, jclass jcls, jobject jiter, jfloat jmissing, jint jmax_bin,
    jint jnthread, jlongArray jout) {
  API_BEGIN();
  common::AssertGPUSupport();
  API_END();
}
|
||||
} // namespace jni
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_USE_CUDA
|
||||
398
jvm-packages/xgboost4j-gpu/src/native/xgboost4j-gpu.cu
Normal file
398
jvm-packages/xgboost4j-gpu/src/native/xgboost4j-gpu.cu
Normal file
@ -0,0 +1,398 @@
|
||||
#include <jni.h>
|
||||
#include <thrust/system/cuda/experimental/pinned_allocator.h>
|
||||
|
||||
#include "../../../../src/common/device_helpers.cuh"
|
||||
#include "../../../../src/data/array_interface.h"
|
||||
#include "jvm_utils.h"
|
||||
#include <xgboost/c_api.h>
|
||||
|
||||
namespace xgboost {
|
||||
namespace jni {
|
||||
|
||||
template <typename T, typename Alloc>
|
||||
T const *RawPtr(std::vector<T, Alloc> const &data) {
|
||||
return data.data();
|
||||
}
|
||||
|
||||
template <typename T, typename Alloc> T *RawPtr(std::vector<T, Alloc> &data) {
|
||||
return data.data();
|
||||
}
|
||||
|
||||
template <typename T> T const *RawPtr(dh::device_vector<T> const &data) {
|
||||
return data.data().get();
|
||||
}
|
||||
|
||||
template <typename T> T *RawPtr(dh::device_vector<T> &data) {
|
||||
return data.data().get();
|
||||
}
|
||||
|
||||
template <typename T> T CheckJvmCall(T const &v, JNIEnv *jenv) {
|
||||
if (!v) {
|
||||
CHECK(jenv->ExceptionOccurred());
|
||||
jenv->ExceptionDescribe();
|
||||
}
|
||||
return v;
|
||||
}
|
||||
|
||||
template <typename VCont>
|
||||
void CopyColumnMask(xgboost::ArrayInterface const &interface,
|
||||
std::vector<Json> const &columns, cudaMemcpyKind kind,
|
||||
size_t c, VCont *p_mask, Json *p_out, cudaStream_t stream) {
|
||||
auto &mask = *p_mask;
|
||||
auto &out = *p_out;
|
||||
auto size = sizeof(typename VCont::value_type) * interface.num_rows *
|
||||
interface.num_cols;
|
||||
mask.resize(size);
|
||||
CHECK(RawPtr(mask));
|
||||
CHECK(size);
|
||||
CHECK(interface.valid.Data());
|
||||
dh::safe_cuda(
|
||||
cudaMemcpyAsync(RawPtr(mask), interface.valid.Data(), size, kind, stream));
|
||||
auto const &mask_column = columns[c]["mask"];
|
||||
out["mask"] = Object();
|
||||
std::vector<Json> mask_data{
|
||||
Json{reinterpret_cast<Integer::Int>(RawPtr(mask))},
|
||||
Json{get<Boolean const>(mask_column["data"][1])}};
|
||||
out["mask"]["data"] = Array(std::move(mask_data));
|
||||
if (get<Array const>(mask_column["shape"]).size() == 2) {
|
||||
std::vector<Json> mask_shape{
|
||||
Json{get<Integer const>(mask_column["shape"][0])},
|
||||
Json{get<Integer const>(mask_column["shape"][1])}};
|
||||
out["mask"]["shape"] = Array(std::move(mask_shape));
|
||||
} else if (get<Array const>(mask_column["shape"]).size() == 1) {
|
||||
std::vector<Json> mask_shape{
|
||||
Json{get<Integer const>(mask_column["shape"][0])}};
|
||||
out["mask"]["shape"] = Array(std::move(mask_shape));
|
||||
} else {
|
||||
LOG(FATAL) << "Invalid shape of mask";
|
||||
}
|
||||
out["mask"]["typestr"] = String("<t1");
|
||||
out["mask"]["version"] = Integer(1);
|
||||
}
|
||||
|
||||
template <typename DCont, typename VCont>
|
||||
void CopyInterface(std::vector<xgboost::ArrayInterface> &interface_arr,
|
||||
std::vector<Json> const &columns, cudaMemcpyKind kind,
|
||||
std::vector<DCont> *p_data, std::vector<VCont> *p_mask,
|
||||
std::vector<xgboost::Json> *p_out, cudaStream_t stream) {
|
||||
p_data->resize(interface_arr.size());
|
||||
p_mask->resize(interface_arr.size());
|
||||
p_out->resize(interface_arr.size());
|
||||
for (size_t c = 0; c < interface_arr.size(); ++c) {
|
||||
auto &interface = interface_arr.at(c);
|
||||
size_t element_size = interface.ElementSize();
|
||||
size_t size = element_size * interface.num_rows * interface.num_cols;
|
||||
|
||||
auto &data = (*p_data)[c];
|
||||
auto &mask = (*p_mask)[c];
|
||||
data.resize(size);
|
||||
dh::safe_cuda(cudaMemcpyAsync(RawPtr(data), interface.data, size, kind, stream));
|
||||
|
||||
auto &out = (*p_out)[c];
|
||||
out = Object();
|
||||
std::vector<Json> j_data{
|
||||
Json{Integer(reinterpret_cast<Integer::Int>(RawPtr(data)))},
|
||||
Json{Boolean{false}}};
|
||||
|
||||
out["data"] = Array(std::move(j_data));
|
||||
out["shape"] = Array(std::vector<Json>{Json(Integer(interface.num_rows)),
|
||||
Json(Integer(interface.num_cols))});
|
||||
|
||||
if (interface.valid.Data()) {
|
||||
CopyColumnMask(interface, columns, kind, c, &mask, &out, stream);
|
||||
}
|
||||
out["typestr"] = String("<f4");
|
||||
out["version"] = Integer(1);
|
||||
}
|
||||
}
|
||||
|
||||
/*!
 * \brief Copy a single-column meta info array (device to device) into `out` and
 *        rewrite the data pointer of the JSON interface to point at the copy.
 *
 * \param p_interface JSON array holding exactly one array interface object; its
 *                    `data[0]` entry is overwritten with the address of the copy
 * \param out         staging buffer that receives the bytes
 * \param stream      CUDA stream the asynchronous copy is issued on
 */
void CopyMetaInfo(Json *p_interface, dh::device_vector<float> *out, cudaStream_t stream) {
  auto &j_interface = *p_interface;
  CHECK_EQ(get<Array const>(j_interface).size(), 1);
  auto object = get<Object>(get<Array>(j_interface)[0]);
  ArrayInterface interface(object);
  size_t element_size = interface.ElementSize();
  size_t size = element_size * interface.num_rows;
  // `size` bytes are copied below, so the buffer is sized in bytes (rounded up to
  // whole floats) rather than in rows; sizing by rows overflows the buffer whenever
  // the element is wider than a float. For 4-byte elements this equals num_rows.
  out->resize((size + sizeof(float) - 1) / sizeof(float));
  dh::safe_cuda(cudaMemcpyAsync(RawPtr(*out), interface.data, size,
                                cudaMemcpyDeviceToDevice, stream));
  j_interface[0]["data"][0] = reinterpret_cast<Integer::Int>(RawPtr(*out));
}
|
||||
|
||||
template <typename DCont, typename VCont> struct DataFrame {
|
||||
std::vector<DCont> data;
|
||||
std::vector<VCont> valid;
|
||||
std::vector<Json> interfaces;
|
||||
};
|
||||
|
||||
class DataIteratorProxy {
|
||||
DMatrixHandle proxy_;
|
||||
JNIEnv *jenv_;
|
||||
int jni_status_;
|
||||
jobject jiter_;
|
||||
bool cache_on_host_{true}; // TODO(Bobby): Make this optional.
|
||||
|
||||
template <typename T>
|
||||
using Alloc = thrust::system::cuda::experimental::pinned_allocator<T>;
|
||||
template <typename U>
|
||||
using HostVector = std::vector<U, Alloc<U>>;
|
||||
|
||||
// This vector is created for staging device data on host to save GPU memory.
|
||||
// When space is not of concern, we can stage them on device memory directly.
|
||||
std::vector<
|
||||
std::unique_ptr<DataFrame<HostVector<char>, HostVector<std::uint8_t>>>>
|
||||
host_columns_;
|
||||
// TODO(Bobby): Use this instead of `host_columns_` if staging is not
|
||||
// required.
|
||||
std::vector<std::unique_ptr<DataFrame<dh::device_vector<char>,
|
||||
dh::device_vector<std::uint8_t>>>>
|
||||
device_columns_;
|
||||
|
||||
// Staging area for metainfo.
|
||||
// TODO(Bobby): label_upper_bound, label_lower_bound, group.
|
||||
std::vector<std::unique_ptr<dh::device_vector<float>>> labels_;
|
||||
std::vector<std::unique_ptr<dh::device_vector<float>>> weights_;
|
||||
std::vector<std::unique_ptr<dh::device_vector<float>>> base_margins_;
|
||||
std::vector<Json> label_interfaces_;
|
||||
std::vector<Json> weight_interfaces_;
|
||||
std::vector<Json> margin_interfaces_;
|
||||
|
||||
size_t it_{0};
|
||||
size_t n_batches_{0};
|
||||
bool initialized_{false};
|
||||
jobject last_batch_ {nullptr};
|
||||
|
||||
// Temp buffer on device, each `dh::device_vector` represents a column
|
||||
// from cudf.
|
||||
std::vector<dh::device_vector<char>> staging_data_;
|
||||
std::vector<dh::device_vector<uint8_t>> staging_mask_;
|
||||
|
||||
cudaStream_t copy_stream_;
|
||||
|
||||
public:
|
||||
// Create the proxy DMatrix, look up the JNIEnv of the current thread (the GetEnv
// status is kept in `jni_status_`), reset the iteration state and create the
// non-blocking CUDA stream used for the staging copies.
explicit DataIteratorProxy(jobject jiter, bool cache_on_host = true)
    : jiter_{jiter}, cache_on_host_{cache_on_host} {
  XGProxyDMatrixCreate(&proxy_);
  jni_status_ =
      GlobalJvm()->GetEnv(reinterpret_cast<void **>(&jenv_), JNI_VERSION_1_6);
  this->Reset();
  // Pass the address of the stream member for cudaStreamCreateWithFlags to fill in.
  dh::safe_cuda(cudaStreamCreateWithFlags(&copy_stream_, cudaStreamNonBlocking));
}
|
||||
~DataIteratorProxy() { XGDMatrixFree(proxy_);
|
||||
dh::safe_cuda(cudaStreamDestroy(copy_stream_));
|
||||
}
|
||||
|
||||
DMatrixHandle GetDMatrixHandle() const { return proxy_; }
|
||||
|
||||
// Helper function for staging meta info.
|
||||
void StageMetaInfo(Json json_interface) {
|
||||
CHECK(!IsA<Null>(json_interface));
|
||||
auto json_map = get<Object const>(json_interface);
|
||||
if (json_map.find("label_str") == json_map.cend()) {
|
||||
LOG(FATAL) << "Must have a label field.";
|
||||
}
|
||||
|
||||
Json label = json_interface["label_str"];
|
||||
CHECK(!IsA<Null>(label));
|
||||
labels_.emplace_back(new dh::device_vector<float>);
|
||||
CopyMetaInfo(&label, labels_.back().get(), copy_stream_);
|
||||
label_interfaces_.emplace_back(label);
|
||||
|
||||
std::string str;
|
||||
Json::Dump(label, &str);
|
||||
XGDMatrixSetInfoFromInterface(proxy_, "label", str.c_str());
|
||||
|
||||
if (json_map.find("weight_str") != json_map.cend()) {
|
||||
Json weight = json_interface["weight_str"];
|
||||
CHECK(!IsA<Null>(weight));
|
||||
weights_.emplace_back(new dh::device_vector<float>);
|
||||
CopyMetaInfo(&weight, weights_.back().get(), copy_stream_);
|
||||
weight_interfaces_.emplace_back(weight);
|
||||
|
||||
Json::Dump(weight, &str);
|
||||
XGDMatrixSetInfoFromInterface(proxy_, "weight", str.c_str());
|
||||
}
|
||||
|
||||
if (json_map.find("basemargin_str") != json_map.cend()) {
|
||||
Json basemargin = json_interface["basemargin_str"];
|
||||
base_margins_.emplace_back(new dh::device_vector<float>);
|
||||
CopyMetaInfo(&basemargin, base_margins_.back().get(), copy_stream_);
|
||||
margin_interfaces_.emplace_back(basemargin);
|
||||
|
||||
Json::Dump(basemargin, &str);
|
||||
XGDMatrixSetInfoFromInterface(proxy_, "base_margin", str.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
void CloseJvmBatch() {
|
||||
if (last_batch_) {
|
||||
jclass batch_class = CheckJvmCall(jenv_->GetObjectClass(last_batch_), jenv_);
|
||||
jmethodID closeMethod = CheckJvmCall(jenv_->GetMethodID(batch_class, "close", "()V"), jenv_);
|
||||
jenv_->CallVoidMethod(last_batch_, closeMethod);
|
||||
last_batch_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
void Reset() {
|
||||
it_ = 0;
|
||||
this->CloseJvmBatch();
|
||||
}
|
||||
|
||||
int32_t PullIterFromJVM() {
|
||||
jclass iterClass = jenv_->FindClass("java/util/Iterator");
|
||||
this->CloseJvmBatch();
|
||||
|
||||
jmethodID has_next =
|
||||
CheckJvmCall(jenv_->GetMethodID(iterClass, "hasNext", "()Z"), jenv_);
|
||||
jmethodID next = CheckJvmCall(
|
||||
jenv_->GetMethodID(iterClass, "next", "()Ljava/lang/Object;"), jenv_);
|
||||
|
||||
if (jenv_->CallBooleanMethod(jiter_, has_next)) {
|
||||
// batch should be ColumnBatch from jvm
|
||||
jobject batch = CheckJvmCall(jenv_->CallObjectMethod(jiter_, next), jenv_);
|
||||
jclass batch_class = CheckJvmCall(jenv_->GetObjectClass(batch), jenv_);
|
||||
jmethodID getArrayInterfaceJson = CheckJvmCall(jenv_->GetMethodID(
|
||||
batch_class, "getArrayInterfaceJson", "()Ljava/lang/String;"), jenv_);
|
||||
|
||||
auto jinterface =
|
||||
static_cast<jstring>(jenv_->CallObjectMethod(batch, getArrayInterfaceJson));
|
||||
CheckJvmCall(jinterface, jenv_);
|
||||
char const *c_interface_str =
|
||||
CheckJvmCall(jenv_->GetStringUTFChars(jinterface, nullptr), jenv_);
|
||||
|
||||
StageData(c_interface_str);
|
||||
|
||||
jenv_->ReleaseStringUTFChars(jinterface, c_interface_str);
|
||||
|
||||
last_batch_ = batch;
|
||||
return 1;
|
||||
} else {
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
void StageData(std::string interface_str) {
|
||||
++n_batches_;
|
||||
// DataFrame
|
||||
using T = decltype(host_columns_)::value_type::element_type;
|
||||
host_columns_.emplace_back(std::unique_ptr<T>(new T));
|
||||
|
||||
// Stage the meta info.
|
||||
auto json_interface =
|
||||
Json::Load({interface_str.c_str(), interface_str.size()});
|
||||
CHECK(!IsA<Null>(json_interface));
|
||||
StageMetaInfo(json_interface);
|
||||
|
||||
Json features = json_interface["features_str"];
|
||||
auto json_columns = get<Array const>(features);
|
||||
std::vector<ArrayInterface> interfaces;
|
||||
|
||||
// Stage the data
|
||||
for (auto &json_col : json_columns) {
|
||||
auto column = ArrayInterface(get<Object const>(json_col));
|
||||
interfaces.emplace_back(column);
|
||||
}
|
||||
Json::Dump(features, &interface_str);
|
||||
CopyInterface(interfaces, json_columns, cudaMemcpyDeviceToHost,
|
||||
&host_columns_.back()->data, &host_columns_.back()->valid,
|
||||
&host_columns_.back()->interfaces, copy_stream_);
|
||||
|
||||
XGProxyDMatrixSetDataCudaColumnar(proxy_, interface_str.c_str());
|
||||
it_++;
|
||||
}
|
||||
|
||||
int NextFirstLoop() {
|
||||
try {
|
||||
dh::safe_cuda(cudaStreamSynchronize(copy_stream_));
|
||||
if (this->PullIterFromJVM()) {
|
||||
return 1;
|
||||
} else {
|
||||
initialized_ = true;
|
||||
return 0;
|
||||
}
|
||||
} catch (dmlc::Error const &e) {
|
||||
if (jni_status_ == JNI_EDETACHED) {
|
||||
GlobalJvm()->DetachCurrentThread();
|
||||
}
|
||||
LOG(FATAL) << e.what();
|
||||
}
|
||||
LOG(FATAL) << "Unreachable";
|
||||
return 1;
|
||||
}
|
||||
|
||||
int NextSecondLoop() {
|
||||
std::string str;
|
||||
// Meta
|
||||
auto const &label = this->label_interfaces_.at(it_);
|
||||
Json::Dump(label, &str);
|
||||
XGDMatrixSetInfoFromInterface(proxy_, "label", str.c_str());
|
||||
|
||||
if (n_batches_ == this->weight_interfaces_.size()) {
|
||||
auto const &weight = this->weight_interfaces_.at(it_);
|
||||
Json::Dump(weight, &str);
|
||||
XGDMatrixSetInfoFromInterface(proxy_, "weight", str.c_str());
|
||||
}
|
||||
|
||||
if (n_batches_ == this->margin_interfaces_.size()) {
|
||||
auto const &base_margin = this->margin_interfaces_.at(it_);
|
||||
Json::Dump(base_margin, &str);
|
||||
XGDMatrixSetInfoFromInterface(proxy_, "base_margin", str.c_str());
|
||||
}
|
||||
|
||||
// Data
|
||||
auto const &json_interface = host_columns_.at(it_)->interfaces;
|
||||
|
||||
std::vector<ArrayInterface> in;
|
||||
for (auto interface : json_interface) {
|
||||
auto column = ArrayInterface(get<Object const>(interface));
|
||||
in.emplace_back(column);
|
||||
}
|
||||
std::vector<Json> out;
|
||||
CopyInterface(in, json_interface, cudaMemcpyHostToDevice, &staging_data_,
|
||||
&staging_mask_, &out, nullptr);
|
||||
|
||||
Json temp{Array(std::move(out))};
|
||||
std::string interface_str;
|
||||
Json::Dump(temp, &interface_str);
|
||||
XGProxyDMatrixSetDataCudaColumnar(proxy_, interface_str.c_str());
|
||||
it_++;
|
||||
return 1;
|
||||
}
|
||||
|
||||
int Next() {
|
||||
if (!initialized_) {
|
||||
return NextFirstLoop();
|
||||
} else {
|
||||
if (it_ == n_batches_) {
|
||||
return 0;
|
||||
}
|
||||
return NextSecondLoop();
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
namespace {
|
||||
void Reset(DataIterHandle self) {
|
||||
static_cast<xgboost::jni::DataIteratorProxy *>(self)->Reset();
|
||||
}
|
||||
|
||||
int Next(DataIterHandle self) {
|
||||
return static_cast<xgboost::jni::DataIteratorProxy *>(self)->Next();
|
||||
}
|
||||
} // anonymous namespace
|
||||
|
||||
/*!
 * \brief Create a device quantile DMatrix from a Java iterator of column batches.
 *
 * \param jenv     JNI environment of the calling thread
 * \param jcls     calling Java class (unused)
 * \param jiter    Java iterator that yields the batches
 * \param jmissing value treated as missing
 * \param jmax_bin maximum number of bins
 * \param jnthread number of threads
 * \param jout     Java long array that receives the resulting handle
 * \return the status code of XGDeviceQuantileDMatrixCreateFromCallback
 */
XGB_DLL jint XGDeviceQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass jcls,
                                                           jobject jiter,
                                                           jfloat jmissing,
                                                           jint jmax_bin, jint jnthread,
                                                           jlongArray jout) {
  xgboost::jni::DataIteratorProxy proxy(jiter);
  // Start from a null handle: `result` is handed to setHandle even when creation
  // fails, and reading it uninitialized would publish an indeterminate pointer.
  DMatrixHandle result{nullptr};
  auto ret = XGDeviceQuantileDMatrixCreateFromCallback(
      &proxy, proxy.GetDMatrixHandle(), Reset, Next, jmissing, jnthread,
      jmax_bin, &result);
  setHandle(jenv, jout, result);
  return ret;
}
|
||||
} // namespace jni
|
||||
} // namespace xgboost
|
||||
@ -1 +0,0 @@
|
||||
../../xgboost4j/src/test
|
||||
@ -0,0 +1,127 @@
|
||||
/*
|
||||
Copyright (c) 2021 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.gpu.java;
|
||||
|
||||
import java.io.File;
|
||||
import java.util.HashMap;
|
||||
import java.util.LinkedList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import junit.framework.TestCase;
|
||||
|
||||
import org.junit.Test;
|
||||
|
||||
import ai.rapids.cudf.DType;
|
||||
import ai.rapids.cudf.Schema;
|
||||
import ai.rapids.cudf.Table;
|
||||
import ai.rapids.cudf.ColumnVector;
|
||||
import ai.rapids.cudf.CSVOptions;
|
||||
import ml.dmlc.xgboost4j.java.Booster;
|
||||
import ml.dmlc.xgboost4j.java.ColumnBatch;
|
||||
import ml.dmlc.xgboost4j.java.DMatrix;
|
||||
import ml.dmlc.xgboost4j.java.DeviceQuantileDMatrix;
|
||||
import ml.dmlc.xgboost4j.java.XGBoost;
|
||||
import ml.dmlc.xgboost4j.java.XGBoostError;
|
||||
|
||||
/**
|
||||
* Tests Booster training on GPU with both DMatrix and DeviceQuantileDMatrix
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public class BoosterTest {
|
||||
|
||||
@Test
|
||||
public void testBooster() throws XGBoostError {
|
||||
String trainingDataPath = "../../demo/data/veterans_lung_cancer.csv";
|
||||
Schema schema = Schema.builder()
|
||||
.column(DType.FLOAT32, "A")
|
||||
.column(DType.FLOAT32, "B")
|
||||
.column(DType.FLOAT32, "C")
|
||||
.column(DType.FLOAT32, "D")
|
||||
|
||||
.column(DType.FLOAT32, "E")
|
||||
.column(DType.FLOAT32, "F")
|
||||
.column(DType.FLOAT32, "G")
|
||||
.column(DType.FLOAT32, "H")
|
||||
|
||||
.column(DType.FLOAT32, "I")
|
||||
.column(DType.FLOAT32, "J")
|
||||
.column(DType.FLOAT32, "K")
|
||||
.column(DType.FLOAT32, "L")
|
||||
|
||||
.column(DType.FLOAT32, "label")
|
||||
.build();
|
||||
CSVOptions opts = CSVOptions.builder()
|
||||
.hasHeader().build();
|
||||
|
||||
int maxBin = 16;
|
||||
int round = 100;
|
||||
//set params
|
||||
Map<String, Object> paramMap = new HashMap<String, Object>() {
|
||||
{
|
||||
put("max_depth", 2);
|
||||
put("objective", "binary:logistic");
|
||||
put("num_round", round);
|
||||
put("num_workers", 1);
|
||||
put("tree_method", "gpu_hist");
|
||||
put("predictor", "gpu_predictor");
|
||||
put("max_bin", maxBin);
|
||||
}
|
||||
};
|
||||
|
||||
try (Table tmpTable = Table.readCSV(schema, opts, new File(trainingDataPath))) {
|
||||
ColumnVector[] df = new ColumnVector[12];
|
||||
for (int i = 0; i < 12; ++i) {
|
||||
df[i] = tmpTable.getColumn(i);
|
||||
}
|
||||
try (Table X = new Table(df);) {
|
||||
ColumnVector[] labels = new ColumnVector[1];
|
||||
labels[0] = tmpTable.getColumn(12);
|
||||
|
||||
try (Table y = new Table(labels);) {
|
||||
|
||||
CudfColumnBatch batch = new CudfColumnBatch(X, y, null, null);
|
||||
CudfColumn labelColumn = CudfColumn.from(tmpTable.getColumn(12));
|
||||
|
||||
//set watchList
|
||||
HashMap<String, DMatrix> watches = new HashMap<>();
|
||||
|
||||
DMatrix dMatrix1 = new DMatrix(batch, Float.NaN, 1);
|
||||
dMatrix1.setLabel(labelColumn);
|
||||
watches.put("train", dMatrix1);
|
||||
Booster model1 = XGBoost.train(dMatrix1, paramMap, round, watches, null, null);
|
||||
|
||||
List<ColumnBatch> tables = new LinkedList<>();
|
||||
tables.add(batch);
|
||||
DMatrix incrementalDMatrix = new DeviceQuantileDMatrix(tables.iterator(), Float.NaN, maxBin, 1);
|
||||
//set watchList
|
||||
HashMap<String, DMatrix> watches1 = new HashMap<>();
|
||||
watches1.put("train", incrementalDMatrix);
|
||||
Booster model2 = XGBoost.train(incrementalDMatrix, paramMap, round, watches1, null, null);
|
||||
|
||||
float[][] predicat1 = model1.predict(dMatrix1);
|
||||
float[][] predicat2 = model2.predict(dMatrix1);
|
||||
|
||||
for (int i = 0; i < tmpTable.getRowCount(); i++) {
|
||||
// Compare the absolute difference: a signed difference would pass for any negative gap, however large.
|
||||
TestCase.assertTrue(Math.abs(predicat1[i][0] - predicat2[i][0]) < 1e-6);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
@ -0,0 +1,123 @@
|
||||
/*
|
||||
Copyright (c) 2021 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.gpu.java;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.LinkedList;
|
||||
import java.util.List;
|
||||
|
||||
import junit.framework.TestCase;
|
||||
|
||||
import com.google.common.primitives.Floats;
|
||||
|
||||
import org.apache.commons.lang.ArrayUtils;
|
||||
import org.junit.Test;
|
||||
|
||||
import ai.rapids.cudf.Table;
|
||||
import ml.dmlc.xgboost4j.java.DMatrix;
|
||||
import ml.dmlc.xgboost4j.java.DeviceQuantileDMatrix;
|
||||
import ml.dmlc.xgboost4j.java.ColumnBatch;
|
||||
import ml.dmlc.xgboost4j.java.XGBoostError;
|
||||
|
||||
/**
|
||||
* Test suite for DMatrix based on GPU
|
||||
*/
|
||||
public class DMatrixTest {
|
||||
|
||||
@Test
|
||||
public void testCreateFromArrayInterfaceColumns() {
|
||||
Float[] labelFloats = new Float[]{2f, 4f, 6f, 8f, 10f};
|
||||
|
||||
Throwable ex = null;
|
||||
try (
|
||||
Table X = new Table.TestBuilder().column(1.f, null, 5.f, 7.f, 9.f).build();
|
||||
Table y = new Table.TestBuilder().column(labelFloats).build();
|
||||
Table w = new Table.TestBuilder().column(labelFloats).build();
|
||||
Table margin = new Table.TestBuilder().column(labelFloats).build();) {
|
||||
|
||||
CudfColumnBatch cudfDataFrame = new CudfColumnBatch(X, y, w, null);
|
||||
|
||||
CudfColumn labelColumn = CudfColumn.from(y.getColumn(0));
|
||||
CudfColumn weightColumn = CudfColumn.from(w.getColumn(0));
|
||||
CudfColumn baseMarginColumn = CudfColumn.from(margin.getColumn(0));
|
||||
|
||||
DMatrix dMatrix = new DMatrix(cudfDataFrame, 0, 1);
|
||||
dMatrix.setLabel(labelColumn);
|
||||
dMatrix.setWeight(weightColumn);
|
||||
dMatrix.setBaseMargin(baseMarginColumn);
|
||||
|
||||
float[] anchor = convertFloatTofloat(labelFloats);
|
||||
float[] label = dMatrix.getLabel();
|
||||
float[] weight = dMatrix.getWeight();
|
||||
float[] baseMargin = dMatrix.getBaseMargin();
|
||||
|
||||
TestCase.assertTrue(Arrays.equals(anchor, label));
|
||||
TestCase.assertTrue(Arrays.equals(anchor, weight));
|
||||
TestCase.assertTrue(Arrays.equals(anchor, baseMargin));
|
||||
} catch (Throwable e) {
|
||||
ex = e;
|
||||
e.printStackTrace();
|
||||
}
|
||||
TestCase.assertNull(ex);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCreateFromColumnDataIterator() throws XGBoostError {
|
||||
|
||||
Float[] label1 = {25f, 21f, 22f, 20f, 24f};
|
||||
Float[] weight1 = {1.3f, 2.31f, 0.32f, 3.3f, 1.34f};
|
||||
Float[] baseMargin1 = {1.2f, 0.2f, 1.3f, 2.4f, 3.5f};
|
||||
|
||||
Float[] label2 = {9f, 5f, 4f, 10f, 12f};
|
||||
Float[] weight2 = {3.0f, 1.3f, 3.2f, 0.3f, 1.34f};
|
||||
Float[] baseMargin2 = {0.2f, 2.5f, 3.1f, 4.4f, 2.2f};
|
||||
|
||||
try (
|
||||
Table X_0 = new Table.TestBuilder()
|
||||
.column(1.2f, null, 5.2f, 7.2f, 9.2f)
|
||||
.column(0.2f, 0.4f, 0.6f, 2.6f, 0.10f)
|
||||
.build();
|
||||
Table y_0 = new Table.TestBuilder().column(label1).build();
|
||||
Table w_0 = new Table.TestBuilder().column(weight1).build();
|
||||
Table m_0 = new Table.TestBuilder().column(baseMargin1).build();
|
||||
Table X_1 = new Table.TestBuilder().column(11.2f, 11.2f, 15.2f, 17.2f, 19.2f)
|
||||
.column(1.2f, 1.4f, null, 12.6f, 10.10f).build();
|
||||
Table y_1 = new Table.TestBuilder().column(label2).build();
|
||||
Table w_1 = new Table.TestBuilder().column(weight2).build();
|
||||
Table m_1 = new Table.TestBuilder().column(baseMargin2).build();) {
|
||||
|
||||
List<ColumnBatch> tables = new LinkedList<>();
|
||||
|
||||
tables.add(new CudfColumnBatch(X_0, y_0, w_0, m_0));
|
||||
tables.add(new CudfColumnBatch(X_1, y_1, w_1, m_1));
|
||||
|
||||
DMatrix dmat = new DeviceQuantileDMatrix(tables.iterator(), 0.0f, 8, 1);
|
||||
|
||||
float[] anchorLabel = convertFloatTofloat((Float[]) ArrayUtils.addAll(label1, label2));
|
||||
float[] anchorWeight = convertFloatTofloat((Float[]) ArrayUtils.addAll(weight1, weight2));
|
||||
float[] anchorBaseMargin = convertFloatTofloat((Float[]) ArrayUtils.addAll(baseMargin1, baseMargin2));
|
||||
|
||||
TestCase.assertTrue(Arrays.equals(anchorLabel, dmat.getLabel()));
|
||||
TestCase.assertTrue(Arrays.equals(anchorWeight, dmat.getWeight()));
|
||||
TestCase.assertTrue(Arrays.equals(anchorBaseMargin, dmat.getBaseMargin()));
|
||||
}
|
||||
}
|
||||
|
||||
private float[] convertFloatTofloat(Float[] in) {
|
||||
return Floats.toArray(Arrays.asList(in));
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,40 @@
|
||||
/*
|
||||
Copyright (c) 2021 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
/**
|
||||
* The abstracted XGBoost Column to get the cuda array interface which is used to
|
||||
* set the information for DMatrix.
|
||||
*
|
||||
*/
|
||||
public abstract class Column implements AutoCloseable {
|
||||
|
||||
/**
|
||||
* Get the cuda array interface json string for the Column which can be representing
|
||||
* weight, label, base margin column.
|
||||
*
|
||||
* This API will be called by
|
||||
* {@link DMatrix#setLabel(Column)}
|
||||
* {@link DMatrix#setWeight(Column)}
|
||||
* {@link DMatrix#setBaseMargin(Column)}
|
||||
*/
|
||||
public abstract String getArrayInterfaceJson();
|
||||
|
||||
@Override
|
||||
public void close() throws Exception {}
|
||||
|
||||
}
|
||||
@ -0,0 +1,93 @@
|
||||
/*
|
||||
Copyright (c) 2021 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
import java.util.Iterator;
|
||||
|
||||
/**
|
||||
* The abstracted XGBoost ColumnBatch to get array interface from columnar data format.
|
||||
* For example, the cuDF dataframe which employs apache arrow specification.
|
||||
*/
|
||||
public abstract class ColumnBatch implements AutoCloseable {
|
||||
/**
|
||||
* Get the cuda array interface json string for the whole ColumnBatch including
|
||||
* the must-have feature, label columns and the optional weight, base margin columns.
|
||||
*
|
||||
* This function is called by native code during iteration and could be made a private
|
||||
* method. We keep it public simply to silence the linter.
|
||||
*/
|
||||
public final String getArrayInterfaceJson() {
|
||||
|
||||
StringBuilder builder = new StringBuilder();
|
||||
builder.append("{");
|
||||
String featureStr = this.getFeatureArrayInterface();
|
||||
if (featureStr == null || featureStr.isEmpty()) {
|
||||
throw new RuntimeException("Feature array interface must not be empty");
|
||||
} else {
|
||||
builder.append("\"features_str\":" + featureStr);
|
||||
}
|
||||
|
||||
String labelStr = this.getLabelsArrayInterface();
|
||||
if (labelStr == null || labelStr.isEmpty()) {
|
||||
throw new RuntimeException("Label array interface must not be empty");
|
||||
} else {
|
||||
builder.append(",\"label_str\":" + labelStr);
|
||||
}
|
||||
|
||||
String weightStr = getWeightsArrayInterface();
|
||||
if (weightStr != null && ! weightStr.isEmpty()) {
|
||||
builder.append(",\"weight_str\":" + weightStr);
|
||||
}
|
||||
|
||||
String baseMarginStr = getBaseMarginsArrayInterface();
|
||||
if (baseMarginStr != null && ! baseMarginStr.isEmpty()) {
|
||||
builder.append(",\"basemargin_str\":" + baseMarginStr);
|
||||
}
|
||||
|
||||
builder.append("}");
|
||||
return builder.toString();
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the cuda array interface of the feature columns.
|
||||
* The returned value must not be null or empty
|
||||
*/
|
||||
public abstract String getFeatureArrayInterface();
|
||||
|
||||
/**
|
||||
* Get the cuda array interface of the label columns.
|
||||
* The returned value must not be null or empty if we're creating
|
||||
* {@link DeviceQuantileDMatrix#DeviceQuantileDMatrix(Iterator, float, int, int)}
|
||||
*/
|
||||
public abstract String getLabelsArrayInterface();
|
||||
|
||||
/**
|
||||
* Get the cuda array interface of the weight columns.
|
||||
* The returned value can be null or empty
|
||||
*/
|
||||
public abstract String getWeightsArrayInterface();
|
||||
|
||||
/**
|
||||
* Get the cuda array interface of the base margin columns.
|
||||
* The returned value can be null or empty
|
||||
*/
|
||||
public abstract String getBaseMarginsArrayInterface();
|
||||
|
||||
@Override
|
||||
public void close() throws Exception {}
|
||||
|
||||
}
|
||||
@ -177,6 +177,64 @@ public class DMatrix {
|
||||
this.handle = handle;
|
||||
}
|
||||
|
||||
/**
|
||||
* Create the normal DMatrix from column array interface
|
||||
* @param columnBatch the XGBoost ColumnBatch to provide the cuda array interface
|
||||
* of feature columns
|
||||
* @param missing missing value
|
||||
* @param nthread threads number
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public DMatrix(ColumnBatch columnBatch, float missing, int nthread) throws XGBoostError {
|
||||
long[] out = new long[1];
|
||||
String json = columnBatch.getFeatureArrayInterface();
|
||||
if (json == null || json.isEmpty()) {
|
||||
throw new XGBoostError("Expecting non-empty feature columns' array interface");
|
||||
}
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixCreateFromArrayInterfaceColumns(
|
||||
json, missing, nthread, out));
|
||||
handle = out[0];
|
||||
}
|
||||
|
||||
/**
|
||||
* Set label of DMatrix from cuda array interface
|
||||
*
|
||||
* @param column the XGBoost Column to provide the cuda array interface
|
||||
* of label column
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public void setLabel(Column column) throws XGBoostError {
|
||||
setXGBDMatrixInfo("label", column.getArrayInterfaceJson());
|
||||
}
|
||||
|
||||
/**
|
||||
* Set weight of DMatrix from cuda array interface
|
||||
*
|
||||
* @param column the XGBoost Column to provide the cuda array interface
|
||||
* of weight column
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public void setWeight(Column column) throws XGBoostError {
|
||||
setXGBDMatrixInfo("weight", column.getArrayInterfaceJson());
|
||||
}
|
||||
|
||||
/**
|
||||
* Set base margin of DMatrix from cuda array interface
|
||||
*
|
||||
* @param column the XGBoost Column to provide the cuda array interface
|
||||
* of base margin column
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public void setBaseMargin(Column column) throws XGBoostError {
|
||||
setXGBDMatrixInfo("base_margin", column.getArrayInterfaceJson());
|
||||
}
|
||||
|
||||
private void setXGBDMatrixInfo(String type, String json) throws XGBoostError {
|
||||
if (json == null || json.isEmpty()) {
|
||||
throw new XGBoostError("Empty " + type + " columns' array interface");
|
||||
}
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixSetInfoFromInterface(handle, type, json));
|
||||
}
|
||||
|
||||
/**
|
||||
* set label of dmatrix
|
||||
|
||||
@ -0,0 +1,68 @@
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
import java.util.Iterator;
|
||||
|
||||
/**
|
||||
* DeviceQuantileDMatrix will only be used to train
|
||||
*/
|
||||
public class DeviceQuantileDMatrix extends DMatrix {
|
||||
/**
|
||||
* Create DeviceQuantileDMatrix from iterator based on the cuda array interface
|
||||
* @param iter the XGBoost ColumnBatch batch to provide the corresponding cuda array interface
|
||||
* @param missing the missing value
|
||||
* @param maxBin the max bin
|
||||
* @param nthread the parallelism
|
||||
* @throws XGBoostError
|
||||
*/
|
||||
public DeviceQuantileDMatrix(
|
||||
Iterator<ColumnBatch> iter,
|
||||
float missing,
|
||||
int maxBin,
|
||||
int nthread) throws XGBoostError {
|
||||
super(0);
|
||||
long[] out = new long[1];
|
||||
XGBoostJNI.checkCall(XGBoostJNI.XGDeviceQuantileDMatrixCreateFromCallback(
|
||||
iter, missing, maxBin, nthread, out));
|
||||
handle = out[0];
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setLabel(Column column) throws XGBoostError {
|
||||
throw new XGBoostError("DeviceQuantileDMatrix does not support setLabel.");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setWeight(Column column) throws XGBoostError {
|
||||
throw new XGBoostError("DeviceQuantileDMatrix does not support setWeight.");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setBaseMargin(Column column) throws XGBoostError {
|
||||
throw new XGBoostError("DeviceQuantileDMatrix does not support setBaseMargin.");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setLabel(float[] labels) throws XGBoostError {
|
||||
throw new XGBoostError("DeviceQuantileDMatrix does not support setLabel.");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setWeight(float[] weights) throws XGBoostError {
|
||||
throw new XGBoostError("DeviceQuantileDMatrix does not support setWeight.");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setBaseMargin(float[] baseMargin) throws XGBoostError {
|
||||
throw new XGBoostError("DeviceQuantileDMatrix does not support setBaseMargin.");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setBaseMargin(float[][] baseMargin) throws XGBoostError {
|
||||
throw new XGBoostError("DeviceQuantileDMatrix does not support setBaseMargin.");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setGroup(int[] group) throws XGBoostError {
|
||||
throw new XGBoostError("DeviceQuantileDMatrix does not support setGroup.");
|
||||
}
|
||||
}
|
||||
@ -134,4 +134,13 @@ class XGBoostJNI {
|
||||
// This JNI function does not support the callback function for data preparation yet.
|
||||
final static native int RabitAllreduce(ByteBuffer sendrecvbuf, int count,
|
||||
int enum_dtype, int enum_op);
|
||||
|
||||
public final static native int XGDMatrixSetInfoFromInterface(
|
||||
long handle, String field, String json);
|
||||
|
||||
// Parameter names follow the order used by the caller and the JNI implementation: missing, max bin, thread count.
|
||||
public final static native int XGDeviceQuantileDMatrixCreateFromCallback(
|
||||
java.util.Iterator<ColumnBatch> iter, float missing, int maxBin, int nthread, long[] out);
|
||||
|
||||
public final static native int XGDMatrixCreateFromArrayInterfaceColumns(
|
||||
String featureJson, float missing, int nthread, long[] out);
|
||||
}
|
||||
|
||||
@ -19,6 +19,7 @@
|
||||
#include <xgboost/c_api.h>
|
||||
#include <xgboost/base.h>
|
||||
#include <xgboost/logging.h>
|
||||
#include <xgboost/json.h>
|
||||
#include "./xgboost4j.h"
|
||||
#include <cstring>
|
||||
#include <vector>
|
||||
@ -43,12 +44,14 @@ void setHandle(JNIEnv *jenv, jlongArray jhandle, void* handle) {
|
||||
jenv->SetLongArrayRegion(jhandle, 0, 1, &out);
|
||||
}
|
||||
|
||||
// global JVM
|
||||
static JavaVM* global_jvm = nullptr;
|
||||
JavaVM*& GlobalJvm() {
|
||||
static JavaVM* vm;
|
||||
return vm;
|
||||
}
|
||||
|
||||
// overrides JNI on load
|
||||
jint JNI_OnLoad(JavaVM *vm, void *reserved) {
|
||||
global_jvm = vm;
|
||||
GlobalJvm() = vm;
|
||||
return JNI_VERSION_1_6;
|
||||
}
|
||||
|
||||
@ -58,9 +61,9 @@ XGB_EXTERN_C int XGBoost4jCallbackDataIterNext(
|
||||
DataHolderHandle set_function_handle) {
|
||||
jobject jiter = static_cast<jobject>(data_handle);
|
||||
JNIEnv* jenv;
|
||||
int jni_status = global_jvm->GetEnv((void **)&jenv, JNI_VERSION_1_6);
|
||||
int jni_status = GlobalJvm()->GetEnv((void **)&jenv, JNI_VERSION_1_6);
|
||||
if (jni_status == JNI_EDETACHED) {
|
||||
global_jvm->AttachCurrentThread(reinterpret_cast<void **>(&jenv), nullptr);
|
||||
GlobalJvm()->AttachCurrentThread(reinterpret_cast<void **>(&jenv), nullptr);
|
||||
} else {
|
||||
CHECK(jni_status == JNI_OK);
|
||||
}
|
||||
@ -148,13 +151,13 @@ XGB_EXTERN_C int XGBoost4jCallbackDataIterNext(
|
||||
jenv->DeleteLocalRef(iterClass);
|
||||
// only detach if it is a async call.
|
||||
if (jni_status == JNI_EDETACHED) {
|
||||
global_jvm->DetachCurrentThread();
|
||||
GlobalJvm()->DetachCurrentThread();
|
||||
}
|
||||
return ret_value;
|
||||
} catch(dmlc::Error const& e) {
|
||||
// only detach if it is a async call.
|
||||
if (jni_status == JNI_EDETACHED) {
|
||||
global_jvm->DetachCurrentThread();
|
||||
GlobalJvm()->DetachCurrentThread();
|
||||
}
|
||||
LOG(FATAL) << e.what();
|
||||
return -1;
|
||||
@ -968,3 +971,71 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitAllreduce
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
namespace xgboost {
|
||||
namespace jni {
|
||||
XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass jcls,
|
||||
jobject jiter,
|
||||
jfloat jmissing,
|
||||
jint jmax_bin, jint jnthread,
|
||||
jlongArray jout);
|
||||
} // namespace jni
|
||||
} // namespace xgboost
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGDeviceQuantileDMatrixCreateFromCallback
|
||||
* Signature: (Ljava/util/Iterator;FII[J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDeviceQuantileDMatrixCreateFromCallback
|
||||
(JNIEnv *jenv, jclass jcls, jobject jiter, jfloat jmissing, jint jmax_bin,
|
||||
jint jnthread, jlongArray jout) {
|
||||
return xgboost::jni::XGDeviceQuantileDMatrixCreateFromCallbackImpl(
|
||||
jenv, jcls, jiter, jmissing, jmax_bin, jnthread, jout);
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGDMatrixSetInfoFromInterface
|
||||
* Signature: (JLjava/lang/String;Ljava/lang/String;)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetInfoFromInterface
|
||||
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jstring jjson_columns) {
|
||||
DMatrixHandle handle = (DMatrixHandle) jhandle;
|
||||
const char* field = jenv->GetStringUTFChars(jfield, 0);
|
||||
const char* cjson_columns = jenv->GetStringUTFChars(jjson_columns, 0);
|
||||
|
||||
int ret = XGDMatrixSetInfoFromInterface(handle, field, cjson_columns);
|
||||
JVM_CHECK_CALL(ret);
|
||||
//release
|
||||
if (field) jenv->ReleaseStringUTFChars(jfield, field);
|
||||
if (cjson_columns) jenv->ReleaseStringUTFChars(jjson_columns, cjson_columns);
|
||||
return ret;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGDMatrixCreateFromArrayInterfaceColumns
|
||||
* Signature: (Ljava/lang/String;FI[J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromArrayInterfaceColumns
|
||||
(JNIEnv *jenv, jclass jcls, jstring jjson_columns, jfloat jmissing, jint jnthread, jlongArray jout) {
|
||||
DMatrixHandle result;
|
||||
const char* cjson_columns = jenv->GetStringUTFChars(jjson_columns, nullptr);
|
||||
xgboost::Json config{xgboost::Object{}};
|
||||
auto missing = static_cast<float>(jmissing);
|
||||
auto n_threads = static_cast<int32_t>(jnthread);
|
||||
config["missing"] = xgboost::Number(missing);
|
||||
config["nthread"] = xgboost::Integer(n_threads);
|
||||
std::string config_str;
|
||||
xgboost::Json::Dump(config, &config_str);
|
||||
int ret = XGDMatrixCreateFromCudaColumnar(cjson_columns, config_str.c_str(),
|
||||
&result);
|
||||
JVM_CHECK_CALL(ret);
|
||||
if (cjson_columns) {
|
||||
jenv->ReleaseStringUTFChars(jjson_columns, cjson_columns);
|
||||
}
|
||||
|
||||
setHandle(jenv, jout, result);
|
||||
return ret;
|
||||
}
|
||||
|
||||
@ -335,6 +335,30 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitVersionNumber
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitAllreduce
|
||||
(JNIEnv *, jclass, jobject, jint, jint, jint);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGDMatrixSetInfoFromInterface
|
||||
* Signature: (JLjava/lang/String;Ljava/lang/String;)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetInfoFromInterface
|
||||
(JNIEnv *, jclass, jlong, jstring, jstring);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGDeviceQuantileDMatrixCreateFromCallback
|
||||
* Signature: (Ljava/util/Iterator;FII[J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDeviceQuantileDMatrixCreateFromCallback
|
||||
(JNIEnv *, jclass, jobject, jfloat, jint, jint, jlongArray);
|
||||
|
||||
/*
|
||||
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
|
||||
* Method: XGDMatrixCreateFromArrayInterfaceColumns
|
||||
* Signature: (Ljava/lang/String;FI[J)I
|
||||
*/
|
||||
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromArrayInterfaceColumns
|
||||
(JNIEnv *, jclass, jstring, jfloat, jint, jlongArray);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
@ -22,7 +22,7 @@ RUN \
|
||||
# NCCL2 (License: https://docs.nvidia.com/deeplearning/sdk/nccl-sla/index.html)
|
||||
RUN \
|
||||
export CUDA_SHORT=`echo $CUDA_VERSION_ARG | grep -o -E '[0-9]+\.[0-9]'` && \
|
||||
export NCCL_VERSION=2.4.8-1 && \
|
||||
export NCCL_VERSION=2.8.3-1 && \
|
||||
wget -nv -nc https://developer.download.nvidia.com/compute/machine-learning/repos/rhel7/x86_64/nvidia-machine-learning-repo-rhel7-1.0.0-1.x86_64.rpm && \
|
||||
rpm -i nvidia-machine-learning-repo-rhel7-1.0.0-1.x86_64.rpm && \
|
||||
yum -y update && \
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user