[jvm-packages][xgboost4j-gpu] Support GPU dataframe and DeviceQuantileDMatrix (#7195)

The following classes are added to support dataframes in the Java binding:

- `Column` is an abstract type for a single column in tabular data.
- `ColumnBatch` is an abstract type for a dataframe.

- `CudfColumn` is an implementation of `Column` that consumes a cuDF column.
- `CudfColumnBatch` is an implementation of `ColumnBatch` that consumes a cuDF dataframe.

- `DeviceQuantileDMatrix` is the DMatrix interface for quantized data; it is used for training only.

The Java implementation mimics the Python interface and uses the `__cuda_array_interface__` protocol for exchanging device memory. One difference is that in the JVM package the data batches are staged on the host, since Java iterators cannot be reset.
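For illustration, a minimal usage sketch (assuming `X` is a cuDF `Table` holding the feature columns and `y` a single-column label `Table`, both already resident on the device; the sketch lives in the `ml.dmlc.xgboost4j.gpu.java` package because `CudfColumn` is package-private):

```java
package ml.dmlc.xgboost4j.gpu.java;

import java.util.Collections;
import java.util.Iterator;

import ai.rapids.cudf.Table;
import ml.dmlc.xgboost4j.java.ColumnBatch;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.DeviceQuantileDMatrix;
import ml.dmlc.xgboost4j.java.XGBoostError;

class UsageSketch {
  static void build(Table X, Table y) throws XGBoostError {
    // Wrap the cuDF tables; weight and base margin are optional (null here).
    CudfColumnBatch batch = new CudfColumnBatch(X, y, null, null);

    // A regular DMatrix is built from the feature columns' array interface;
    // meta info is set afterwards from individual Columns.
    DMatrix dmat = new DMatrix(batch, Float.NaN, 1);
    dmat.setLabel(CudfColumn.from(y.getColumn(0)));

    // DeviceQuantileDMatrix consumes an iterator of batches; the label (and
    // optional weight/base margin) is read from each batch during iteration.
    Iterator<ColumnBatch> it = Collections.<ColumnBatch>singletonList(batch).iterator();
    DMatrix qdmat = new DeviceQuantileDMatrix(it, Float.NaN, /*maxBin=*/16, /*nthread=*/1);
  }
}
```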

Co-authored-by: jiamingy <jm.yuan@outlook.com>
Bobby Wang 2021-09-24 14:25:00 +08:00 committed by GitHub
parent d27a427dc5
commit 0ee11dac77
23 changed files with 1388 additions and 18 deletions

Jenkinsfile

@@ -64,7 +64,7 @@ pipeline {
         // The build-gpu-* builds below use Ubuntu image
         'build-gpu-cuda11.0': { BuildCUDA(cuda_version: '11.0', build_rmm: true) },
         'build-gpu-rpkg': { BuildRPackageWithCUDA(cuda_version: '10.1') },
-        'build-jvm-packages-gpu-cuda10.1': { BuildJVMPackagesWithCUDA(spark_version: '3.0.0', cuda_version: '10.1') },
+        'build-jvm-packages-gpu-cuda10.1': { BuildJVMPackagesWithCUDA(spark_version: '3.0.0', cuda_version: '11.0') },
         'build-jvm-packages': { BuildJVMPackages(spark_version: '3.0.0') },
         'build-jvm-doc': { BuildJVMDoc() }
       ])

jvm-packages/CMakeLists.txt

@@ -1,10 +1,20 @@
 find_package(JNI REQUIRED)
-add_library(xgboost4j SHARED
-  ${PROJECT_SOURCE_DIR}/jvm-packages/xgboost4j/src/native/xgboost4j.cpp)
+list(APPEND JVM_SOURCES
+  ${PROJECT_SOURCE_DIR}/jvm-packages/xgboost4j/src/native/xgboost4j.cpp
+  ${PROJECT_SOURCE_DIR}/jvm-packages/xgboost4j-gpu/src/native/xgboost4j-gpu.cpp)
+if (USE_CUDA)
+  list(APPEND JVM_SOURCES
+    ${PROJECT_SOURCE_DIR}/jvm-packages/xgboost4j-gpu/src/native/xgboost4j-gpu.cu)
+endif (USE_CUDA)
+add_library(xgboost4j SHARED ${JVM_SOURCES} ${XGBOOST_OBJ_SOURCES})
 if (ENABLE_ALL_WARNINGS)
   target_compile_options(xgboost4j PUBLIC -Wall -Wextra)
 endif (ENABLE_ALL_WARNINGS)
 target_link_libraries(xgboost4j PRIVATE objxgboost)
 target_include_directories(xgboost4j
   PRIVATE
@@ -15,8 +25,4 @@ target_include_directories(xgboost4j
   ${PROJECT_SOURCE_DIR}/rabit/include)
 set_output_directory(xgboost4j ${PROJECT_SOURCE_DIR}/lib)
-set_target_properties(
-  xgboost4j PROPERTIES
-  CXX_STANDARD 14
-  CXX_STANDARD_REQUIRED ON)
 target_link_libraries(xgboost4j PRIVATE ${JAVA_JVM_LIBRARY})

jvm-packages/xgboost4j-gpu/pom.xml

@@ -12,7 +12,24 @@
   <version>1.5.0-SNAPSHOT</version>
   <packaging>jar</packaging>
+  <properties>
+    <cudf.version>21.08.2</cudf.version>
+    <cudf.classifier>cuda11</cudf.classifier>
+  </properties>
   <dependencies>
+    <dependency>
+      <groupId>ai.rapids</groupId>
+      <artifactId>cudf</artifactId>
+      <version>${cudf.version}</version>
+      <classifier>${cudf.classifier}</classifier>
+      <scope>provided</scope>
+    </dependency>
+    <dependency>
+      <groupId>com.fasterxml.jackson.core</groupId>
+      <artifactId>jackson-databind</artifactId>
+      <version>2.10.5.1</version>
+    </dependency>
     <dependency>
       <groupId>org.apache.hadoop</groupId>
       <artifactId>hadoop-hdfs</artifactId>

jvm-packages/xgboost4j-gpu/src/main/java (deleted symlink)

@@ -1 +0,0 @@
../../../xgboost4j/src/main/java/

jvm-packages/xgboost4j-gpu/src/main/java/ml/dmlc/xgboost4j/gpu/java/CudfColumn.java

@@ -0,0 +1,110 @@
/*
Copyright (c) 2021 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.gpu.java;
import ai.rapids.cudf.BaseDeviceMemoryBuffer;
import ai.rapids.cudf.BufferType;
import ai.rapids.cudf.ColumnVector;
import ai.rapids.cudf.DType;
import ml.dmlc.xgboost4j.java.Column;
/**
 * Wraps the device buffers of a cuDF ColumnVector (Apache Arrow memory layout)
 * and is used to generate the cuda array interface.
 */
class CudfColumn extends Column {
private final long dataPtr; // gpu data buffer address
private final long shape; // row count
private final long validPtr; // gpu valid buffer address
private final int typeSize; // type size in bytes
private final String typeStr; // follow array interface spec
private final long nullCount; // null count
private String arrayInterface = null; // the cuda array interface
public static CudfColumn from(ColumnVector cv) {
BaseDeviceMemoryBuffer dataBuffer = cv.getDeviceBufferFor(BufferType.DATA);
BaseDeviceMemoryBuffer validBuffer = cv.getDeviceBufferFor(BufferType.VALIDITY);
long validPtr = 0;
if (validBuffer != null) {
validPtr = validBuffer.getAddress();
}
DType dType = cv.getType();
String typeStr = "";
if (dType == DType.FLOAT32 || dType == DType.FLOAT64 ||
dType == DType.TIMESTAMP_DAYS || dType == DType.TIMESTAMP_MICROSECONDS ||
dType == DType.TIMESTAMP_MILLISECONDS || dType == DType.TIMESTAMP_NANOSECONDS ||
dType == DType.TIMESTAMP_SECONDS) {
typeStr = "<f" + dType.getSizeInBytes();
} else if (dType == DType.BOOL8 || dType == DType.INT8 || dType == DType.INT16 ||
dType == DType.INT32 || dType == DType.INT64) {
typeStr = "<i" + dType.getSizeInBytes();
} else {
// Unsupported type.
throw new IllegalArgumentException("Unsupported data type: " + dType);
}
return new CudfColumn(dataBuffer.getAddress(), cv.getRowCount(), validPtr,
dType.getSizeInBytes(), typeStr, cv.getNullCount());
}
private CudfColumn(long dataPtr, long shape, long validPtr, int typeSize, String typeStr,
long nullCount) {
this.dataPtr = dataPtr;
this.shape = shape;
this.validPtr = validPtr;
this.typeSize = typeSize;
this.typeStr = typeStr;
this.nullCount = nullCount;
}
@Override
public String getArrayInterfaceJson() {
// Lazily built; a benign race would only rebuild the identical string.
if (arrayInterface == null) {
arrayInterface = CudfUtils.buildArrayInterface(this);
}
return arrayInterface;
}
public long getDataPtr() {
return dataPtr;
}
public long getShape() {
return shape;
}
public long getValidPtr() {
return validPtr;
}
public int getTypeSize() {
return typeSize;
}
public String getTypeStr() {
return typeStr;
}
public long getNullCount() {
return nullCount;
}
}

jvm-packages/xgboost4j-gpu/src/main/java/ml/dmlc/xgboost4j/gpu/java/CudfColumnBatch.java

@@ -0,0 +1,88 @@
/*
Copyright (c) 2021 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.gpu.java;
import java.util.stream.IntStream;
import ai.rapids.cudf.Table;
import ml.dmlc.xgboost4j.java.ColumnBatch;
/**
* Class to wrap CUDF Table to generate the cuda array interface.
*/
public class CudfColumnBatch extends ColumnBatch {
private final Table feature;
private final Table label;
private final Table weight;
private final Table baseMargin;
public CudfColumnBatch(Table feature, Table labels, Table weights, Table baseMargins) {
this.feature = feature;
this.label = labels;
this.weight = weights;
this.baseMargin = baseMargins;
}
@Override
public String getFeatureArrayInterface() {
return getArrayInterface(this.feature);
}
@Override
public String getLabelsArrayInterface() {
return getArrayInterface(this.label);
}
@Override
public String getWeightsArrayInterface() {
return getArrayInterface(this.weight);
}
@Override
public String getBaseMarginsArrayInterface() {
return getArrayInterface(this.baseMargin);
}
@Override
public void close() {
if (feature != null) feature.close();
if (label != null) label.close();
if (weight != null) weight.close();
if (baseMargin != null) baseMargin.close();
}
private String getArrayInterface(Table table) {
if (table == null || table.getNumberOfColumns() == 0) {
return "";
}
return CudfUtils.buildArrayInterface(getAsCudfColumn(table));
}
private CudfColumn[] getAsCudfColumn(Table table) {
if (table == null || table.getNumberOfColumns() == 0) {
// This will never happen.
return new CudfColumn[]{};
}
return IntStream.range(0, table.getNumberOfColumns())
.mapToObj((i) -> table.getColumn(i))
.map(CudfColumn::from)
.toArray(CudfColumn[]::new);
}
}

jvm-packages/xgboost4j-gpu/src/main/java/ml/dmlc/xgboost4j/gpu/java/CudfUtils.java

@@ -0,0 +1,100 @@
/*
Copyright (c) 2021 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.gpu.java;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import com.fasterxml.jackson.core.JsonFactory;
import com.fasterxml.jackson.core.JsonGenerator;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ArrayNode;
import com.fasterxml.jackson.databind.node.JsonNodeFactory;
import com.fasterxml.jackson.databind.node.ObjectNode;
/**
* Cudf utilities to build cuda array interface against {@link CudfColumn}
*/
class CudfUtils {
/**
* Build the cuda array interface based on CudfColumn(s)
* @param cudfColumns the CudfColumn(s) to be built
* @return the json format of cuda array interface
*/
public static String buildArrayInterface(CudfColumn... cudfColumns) {
return new Builder().add(cudfColumns).build();
}
// Helper class to build array interface string
private static class Builder {
private JsonNodeFactory nodeFactory = new JsonNodeFactory(false);
private ArrayNode rootArrayNode = nodeFactory.arrayNode();
private Builder add(CudfColumn... columns) {
if (columns == null || columns.length <= 0) {
throw new IllegalArgumentException("At least one ColumnData is required.");
}
for (CudfColumn cd : columns) {
rootArrayNode.add(buildColumnObject(cd));
}
return this;
}
private String build() {
try {
ByteArrayOutputStream bos = new ByteArrayOutputStream();
JsonGenerator jsonGen = new JsonFactory().createGenerator(bos);
new ObjectMapper().writeTree(jsonGen, rootArrayNode);
return bos.toString();
} catch (IOException ie) {
  throw new RuntimeException("Failed to build array interface.", ie);
}
}
private ObjectNode buildColumnObject(CudfColumn column) {
if (column.getDataPtr() == 0) {
throw new IllegalArgumentException("Empty column data is NOT accepted!");
}
if (column.getTypeStr() == null || column.getTypeStr().isEmpty()) {
throw new IllegalArgumentException("Empty type string is NOT accepted!");
}
ObjectNode colDataObj = buildMetaObject(column.getDataPtr(), column.getShape(),
column.getTypeStr());
if (column.getValidPtr() != 0 && column.getNullCount() != 0) {
ObjectNode validObj = buildMetaObject(column.getValidPtr(), column.getShape(), "<t1");
colDataObj.set("mask", validObj);
}
return colDataObj;
}
private ObjectNode buildMetaObject(long ptr, long shape, final String typeStr) {
ObjectNode objNode = nodeFactory.objectNode();
ArrayNode shapeNode = objNode.putArray("shape");
shapeNode.add(shape);
ArrayNode dataNode = objNode.putArray("data");
dataNode.add(ptr)
.add(false);
objNode.put("typestr", typeStr)
.put("version", 1);
return objNode;
}
}
}
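For reference, `buildArrayInterface` for a single nullable FLOAT32 column of five rows produces JSON of roughly this shape (the pointer values are illustrative, and the `mask` entry is emitted only when the column has a validity buffer and a non-zero null count):

```json
[{
  "shape": [5],
  "data": [139637976727552, false],
  "typestr": "<f4",
  "version": 1,
  "mask": {
    "shape": [5],
    "data": [139637976731648, false],
    "typestr": "<t1",
    "version": 1
  }
}]
```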

jvm-packages/xgboost4j-gpu/src/main/java/ml/dmlc/xgboost4j/java (new symlink)

@@ -0,0 +1 @@
../../../../../../../xgboost4j/src/main/java/ml/dmlc/xgboost4j/java

jvm-packages/xgboost4j-gpu/src/native (deleted symlink)

@@ -1 +0,0 @@
../../xgboost4j/src/native

jvm-packages/xgboost4j/src/native/jvm_utils.h

@@ -0,0 +1,15 @@
#ifndef JVM_UTILS_H_
#define JVM_UTILS_H_

#include <jni.h>
#define JVM_CHECK_CALL(__expr) \
{ \
int __errcode = (__expr); \
if (__errcode != 0) { \
return __errcode; \
} \
}
JavaVM*& GlobalJvm();
void setHandle(JNIEnv *jenv, jlongArray jhandle, void* handle);
#endif // JVM_UTILS_H_

jvm-packages/xgboost4j-gpu/src/native/xgboost4j-gpu.cpp

@@ -0,0 +1,25 @@
//
// Created by bobwang on 2021/9/8.
//
#ifndef XGBOOST_USE_CUDA
#include <jni.h>
#include "../../../../src/common/common.h"
#include "../../../../src/c_api/c_api_error.h"
namespace xgboost {
namespace jni {
XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass jcls,
jobject jiter,
jfloat jmissing,
jint jmax_bin, jint jnthread,
jlongArray jout) {
API_BEGIN();
common::AssertGPUSupport();
API_END();
}
} // namespace jni
} // namespace xgboost
#endif // XGBOOST_USE_CUDA

jvm-packages/xgboost4j-gpu/src/native/xgboost4j-gpu.cu

@@ -0,0 +1,398 @@
#include <jni.h>
#include <thrust/system/cuda/experimental/pinned_allocator.h>
#include "../../../../src/common/device_helpers.cuh"
#include "../../../../src/data/array_interface.h"
#include "jvm_utils.h"
#include <xgboost/c_api.h>
namespace xgboost {
namespace jni {
template <typename T, typename Alloc>
T const *RawPtr(std::vector<T, Alloc> const &data) {
return data.data();
}
template <typename T, typename Alloc> T *RawPtr(std::vector<T, Alloc> &data) {
return data.data();
}
template <typename T> T const *RawPtr(dh::device_vector<T> const &data) {
return data.data().get();
}
template <typename T> T *RawPtr(dh::device_vector<T> &data) {
return data.data().get();
}
template <typename T> T CheckJvmCall(T const &v, JNIEnv *jenv) {
if (!v) {
CHECK(jenv->ExceptionOccurred());
jenv->ExceptionDescribe();
}
return v;
}
template <typename VCont>
void CopyColumnMask(xgboost::ArrayInterface const &interface,
std::vector<Json> const &columns, cudaMemcpyKind kind,
size_t c, VCont *p_mask, Json *p_out, cudaStream_t stream) {
auto &mask = *p_mask;
auto &out = *p_out;
auto size = sizeof(typename VCont::value_type) * interface.num_rows *
interface.num_cols;
mask.resize(size);
CHECK(RawPtr(mask));
CHECK(size);
CHECK(interface.valid.Data());
dh::safe_cuda(
cudaMemcpyAsync(RawPtr(mask), interface.valid.Data(), size, kind, stream));
auto const &mask_column = columns[c]["mask"];
out["mask"] = Object();
std::vector<Json> mask_data{
Json{reinterpret_cast<Integer::Int>(RawPtr(mask))},
Json{get<Boolean const>(mask_column["data"][1])}};
out["mask"]["data"] = Array(std::move(mask_data));
if (get<Array const>(mask_column["shape"]).size() == 2) {
std::vector<Json> mask_shape{
Json{get<Integer const>(mask_column["shape"][0])},
Json{get<Integer const>(mask_column["shape"][1])}};
out["mask"]["shape"] = Array(std::move(mask_shape));
} else if (get<Array const>(mask_column["shape"]).size() == 1) {
std::vector<Json> mask_shape{
Json{get<Integer const>(mask_column["shape"][0])}};
out["mask"]["shape"] = Array(std::move(mask_shape));
} else {
LOG(FATAL) << "Invalid shape of mask";
}
out["mask"]["typestr"] = String("<t1");
out["mask"]["version"] = Integer(1);
}
template <typename DCont, typename VCont>
void CopyInterface(std::vector<xgboost::ArrayInterface> &interface_arr,
std::vector<Json> const &columns, cudaMemcpyKind kind,
std::vector<DCont> *p_data, std::vector<VCont> *p_mask,
std::vector<xgboost::Json> *p_out, cudaStream_t stream) {
p_data->resize(interface_arr.size());
p_mask->resize(interface_arr.size());
p_out->resize(interface_arr.size());
for (size_t c = 0; c < interface_arr.size(); ++c) {
auto &interface = interface_arr.at(c);
size_t element_size = interface.ElementSize();
size_t size = element_size * interface.num_rows * interface.num_cols;
auto &data = (*p_data)[c];
auto &mask = (*p_mask)[c];
data.resize(size);
dh::safe_cuda(cudaMemcpyAsync(RawPtr(data), interface.data, size, kind, stream));
auto &out = (*p_out)[c];
out = Object();
std::vector<Json> j_data{
Json{Integer(reinterpret_cast<Integer::Int>(RawPtr(data)))},
Json{Boolean{false}}};
out["data"] = Array(std::move(j_data));
out["shape"] = Array(std::vector<Json>{Json(Integer(interface.num_rows)),
Json(Integer(interface.num_cols))});
if (interface.valid.Data()) {
CopyColumnMask(interface, columns, kind, c, &mask, &out, stream);
}
out["typestr"] = String("<f4");
out["version"] = Integer(1);
}
}
void CopyMetaInfo(Json *p_interface, dh::device_vector<float> *out, cudaStream_t stream) {
auto &j_interface = *p_interface;
CHECK_EQ(get<Array const>(j_interface).size(), 1);
auto object = get<Object>(get<Array>(j_interface)[0]);
ArrayInterface interface(object);
out->resize(interface.num_rows);
size_t element_size = interface.ElementSize();
size_t size = element_size * interface.num_rows;
dh::safe_cuda(cudaMemcpyAsync(RawPtr(*out), interface.data, size,
cudaMemcpyDeviceToDevice, stream));
j_interface[0]["data"][0] = reinterpret_cast<Integer::Int>(RawPtr(*out));
}
template <typename DCont, typename VCont> struct DataFrame {
std::vector<DCont> data;
std::vector<VCont> valid;
std::vector<Json> interfaces;
};
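// Design note: XGDeviceQuantileDMatrixCreateFromCallback iterates over the
// data more than once, but a java.util.Iterator cannot be rewound.  The proxy
// therefore pulls each batch from the JVM only during the first pass
// (NextFirstLoop), staging a copy in pinned host memory along with the meta
// info; later passes (NextSecondLoop) replay the cached batches by copying
// them back to the device, without touching the JVM iterator again.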
class DataIteratorProxy {
DMatrixHandle proxy_;
JNIEnv *jenv_;
int jni_status_;
jobject jiter_;
bool cache_on_host_{true}; // TODO(Bobby): Make this optional.
template <typename T>
using Alloc = thrust::system::cuda::experimental::pinned_allocator<T>;
template <typename U>
using HostVector = std::vector<U, Alloc<U>>;
// This vector is created for staging device data on host to save GPU memory.
// When space is not of concern, we can stage them on device memory directly.
std::vector<
std::unique_ptr<DataFrame<HostVector<char>, HostVector<std::uint8_t>>>>
host_columns_;
// TODO(Bobby): Use this instead of `host_columns_` if staging is not
// required.
std::vector<std::unique_ptr<DataFrame<dh::device_vector<char>,
dh::device_vector<std::uint8_t>>>>
device_columns_;
// Staging area for metainfo.
// TODO(Bobby): label_upper_bound, label_lower_bound, group.
std::vector<std::unique_ptr<dh::device_vector<float>>> labels_;
std::vector<std::unique_ptr<dh::device_vector<float>>> weights_;
std::vector<std::unique_ptr<dh::device_vector<float>>> base_margins_;
std::vector<Json> label_interfaces_;
std::vector<Json> weight_interfaces_;
std::vector<Json> margin_interfaces_;
size_t it_{0};
size_t n_batches_{0};
bool initialized_{false};
jobject last_batch_ {nullptr};
// Temp buffer on device, each `dh::device_vector` represents a column
// from cudf.
std::vector<dh::device_vector<char>> staging_data_;
std::vector<dh::device_vector<uint8_t>> staging_mask_;
cudaStream_t copy_stream_;
public:
explicit DataIteratorProxy(jobject jiter, bool cache_on_host = true)
: jiter_{jiter}, cache_on_host_{cache_on_host} {
XGProxyDMatrixCreate(&proxy_);
jni_status_ =
GlobalJvm()->GetEnv(reinterpret_cast<void **>(&jenv_), JNI_VERSION_1_6);
this->Reset();
dh::safe_cuda(cudaStreamCreateWithFlags(&copy_stream_, cudaStreamNonBlocking));
}
~DataIteratorProxy() {
  XGDMatrixFree(proxy_);
  dh::safe_cuda(cudaStreamDestroy(copy_stream_));
}
DMatrixHandle GetDMatrixHandle() const { return proxy_; }
// Helper function for staging meta info.
void StageMetaInfo(Json json_interface) {
CHECK(!IsA<Null>(json_interface));
auto json_map = get<Object const>(json_interface);
if (json_map.find("label_str") == json_map.cend()) {
LOG(FATAL) << "Must have a label field.";
}
Json label = json_interface["label_str"];
CHECK(!IsA<Null>(label));
labels_.emplace_back(new dh::device_vector<float>);
CopyMetaInfo(&label, labels_.back().get(), copy_stream_);
label_interfaces_.emplace_back(label);
std::string str;
Json::Dump(label, &str);
XGDMatrixSetInfoFromInterface(proxy_, "label", str.c_str());
if (json_map.find("weight_str") != json_map.cend()) {
Json weight = json_interface["weight_str"];
CHECK(!IsA<Null>(weight));
weights_.emplace_back(new dh::device_vector<float>);
CopyMetaInfo(&weight, weights_.back().get(), copy_stream_);
weight_interfaces_.emplace_back(weight);
Json::Dump(weight, &str);
XGDMatrixSetInfoFromInterface(proxy_, "weight", str.c_str());
}
if (json_map.find("basemargin_str") != json_map.cend()) {
Json basemargin = json_interface["basemargin_str"];
base_margins_.emplace_back(new dh::device_vector<float>);
CopyMetaInfo(&basemargin, base_margins_.back().get(), copy_stream_);
margin_interfaces_.emplace_back(basemargin);
Json::Dump(basemargin, &str);
XGDMatrixSetInfoFromInterface(proxy_, "base_margin", str.c_str());
}
}
void CloseJvmBatch() {
if (last_batch_) {
jclass batch_class = CheckJvmCall(jenv_->GetObjectClass(last_batch_), jenv_);
jmethodID closeMethod = CheckJvmCall(jenv_->GetMethodID(batch_class, "close", "()V"), jenv_);
jenv_->CallVoidMethod(last_batch_, closeMethod);
last_batch_ = nullptr;
}
}
void Reset() {
it_ = 0;
this->CloseJvmBatch();
}
int32_t PullIterFromJVM() {
jclass iterClass = jenv_->FindClass("java/util/Iterator");
this->CloseJvmBatch();
jmethodID has_next =
CheckJvmCall(jenv_->GetMethodID(iterClass, "hasNext", "()Z"), jenv_);
jmethodID next = CheckJvmCall(
jenv_->GetMethodID(iterClass, "next", "()Ljava/lang/Object;"), jenv_);
if (jenv_->CallBooleanMethod(jiter_, has_next)) {
// batch should be ColumnBatch from jvm
jobject batch = CheckJvmCall(jenv_->CallObjectMethod(jiter_, next), jenv_);
jclass batch_class = CheckJvmCall(jenv_->GetObjectClass(batch), jenv_);
jmethodID getArrayInterfaceJson = CheckJvmCall(jenv_->GetMethodID(
batch_class, "getArrayInterfaceJson", "()Ljava/lang/String;"), jenv_);
auto jinterface =
static_cast<jstring>(jenv_->CallObjectMethod(batch, getArrayInterfaceJson));
CheckJvmCall(jinterface, jenv_);
char const *c_interface_str =
CheckJvmCall(jenv_->GetStringUTFChars(jinterface, nullptr), jenv_);
StageData(c_interface_str);
jenv_->ReleaseStringUTFChars(jinterface, c_interface_str);
last_batch_ = batch;
return 1;
} else {
return 0;
}
}
void StageData(std::string interface_str) {
++n_batches_;
// DataFrame
using T = decltype(host_columns_)::value_type::element_type;
host_columns_.emplace_back(std::unique_ptr<T>(new T));
// Stage the meta info.
auto json_interface =
Json::Load({interface_str.c_str(), interface_str.size()});
CHECK(!IsA<Null>(json_interface));
StageMetaInfo(json_interface);
Json features = json_interface["features_str"];
auto json_columns = get<Array const>(features);
std::vector<ArrayInterface> interfaces;
// Stage the data
for (auto &json_col : json_columns) {
auto column = ArrayInterface(get<Object const>(json_col));
interfaces.emplace_back(column);
}
Json::Dump(features, &interface_str);
CopyInterface(interfaces, json_columns, cudaMemcpyDeviceToHost,
&host_columns_.back()->data, &host_columns_.back()->valid,
&host_columns_.back()->interfaces, copy_stream_);
XGProxyDMatrixSetDataCudaColumnar(proxy_, interface_str.c_str());
it_++;
}
int NextFirstLoop() {
try {
dh::safe_cuda(cudaStreamSynchronize(copy_stream_));
if (this->PullIterFromJVM()) {
return 1;
} else {
initialized_ = true;
return 0;
}
} catch (dmlc::Error const &e) {
if (jni_status_ == JNI_EDETACHED) {
GlobalJvm()->DetachCurrentThread();
}
LOG(FATAL) << e.what();
}
LOG(FATAL) << "Unreachable";
return 1;
}
int NextSecondLoop() {
std::string str;
// Meta
auto const &label = this->label_interfaces_.at(it_);
Json::Dump(label, &str);
XGDMatrixSetInfoFromInterface(proxy_, "label", str.c_str());
if (n_batches_ == this->weight_interfaces_.size()) {
auto const &weight = this->weight_interfaces_.at(it_);
Json::Dump(weight, &str);
XGDMatrixSetInfoFromInterface(proxy_, "weight", str.c_str());
}
if (n_batches_ == this->margin_interfaces_.size()) {
auto const &base_margin = this->margin_interfaces_.at(it_);
Json::Dump(base_margin, &str);
XGDMatrixSetInfoFromInterface(proxy_, "base_margin", str.c_str());
}
// Data
auto const &json_interface = host_columns_.at(it_)->interfaces;
std::vector<ArrayInterface> in;
for (auto interface : json_interface) {
auto column = ArrayInterface(get<Object const>(interface));
in.emplace_back(column);
}
std::vector<Json> out;
CopyInterface(in, json_interface, cudaMemcpyHostToDevice, &staging_data_,
&staging_mask_, &out, nullptr);
Json temp{Array(std::move(out))};
std::string interface_str;
Json::Dump(temp, &interface_str);
XGProxyDMatrixSetDataCudaColumnar(proxy_, interface_str.c_str());
it_++;
return 1;
}
int Next() {
if (!initialized_) {
return NextFirstLoop();
} else {
if (it_ == n_batches_) {
return 0;
}
return NextSecondLoop();
}
}
};
namespace {
void Reset(DataIterHandle self) {
static_cast<xgboost::jni::DataIteratorProxy *>(self)->Reset();
}
int Next(DataIterHandle self) {
return static_cast<xgboost::jni::DataIteratorProxy *>(self)->Next();
}
} // anonymous namespace
XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass jcls,
jobject jiter,
jfloat jmissing,
jint jmax_bin, jint jnthread,
jlongArray jout) {
xgboost::jni::DataIteratorProxy proxy(jiter);
DMatrixHandle result;
auto ret = XGDeviceQuantileDMatrixCreateFromCallback(
&proxy, proxy.GetDMatrixHandle(), Reset, Next, jmissing, jnthread,
jmax_bin, &result);
setHandle(jenv, jout, result);
return ret;
}
} // namespace jni
} // namespace xgboost

jvm-packages/xgboost4j-gpu/src/test (deleted symlink)

@@ -1 +0,0 @@
../../xgboost4j/src/test

jvm-packages/xgboost4j-gpu/src/test/java/ml/dmlc/xgboost4j/gpu/java/BoosterTest.java

@@ -0,0 +1,127 @@
/*
Copyright (c) 2021 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.gpu.java;
import java.io.File;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import junit.framework.TestCase;
import org.junit.Test;
import ai.rapids.cudf.DType;
import ai.rapids.cudf.Schema;
import ai.rapids.cudf.Table;
import ai.rapids.cudf.ColumnVector;
import ai.rapids.cudf.CSVOptions;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.ColumnBatch;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.DeviceQuantileDMatrix;
import ml.dmlc.xgboost4j.java.XGBoost;
import ml.dmlc.xgboost4j.java.XGBoostError;
/**
 * Tests that a Booster trained from a regular DMatrix and one trained from a
 * DeviceQuantileDMatrix produce matching predictions.
 */
public class BoosterTest {
@Test
public void testBooster() throws XGBoostError {
String trainingDataPath = "../../demo/data/veterans_lung_cancer.csv";
Schema schema = Schema.builder()
.column(DType.FLOAT32, "A")
.column(DType.FLOAT32, "B")
.column(DType.FLOAT32, "C")
.column(DType.FLOAT32, "D")
.column(DType.FLOAT32, "E")
.column(DType.FLOAT32, "F")
.column(DType.FLOAT32, "G")
.column(DType.FLOAT32, "H")
.column(DType.FLOAT32, "I")
.column(DType.FLOAT32, "J")
.column(DType.FLOAT32, "K")
.column(DType.FLOAT32, "L")
.column(DType.FLOAT32, "label")
.build();
CSVOptions opts = CSVOptions.builder()
.hasHeader().build();
int maxBin = 16;
int round = 100;
//set params
Map<String, Object> paramMap = new HashMap<String, Object>() {
{
put("max_depth", 2);
put("objective", "binary:logistic");
put("num_round", round);
put("num_workers", 1);
put("tree_method", "gpu_hist");
put("predictor", "gpu_predictor");
put("max_bin", maxBin);
}
};
try (Table tmpTable = Table.readCSV(schema, opts, new File(trainingDataPath))) {
ColumnVector[] df = new ColumnVector[12];
for (int i = 0; i < 12; ++i) {
df[i] = tmpTable.getColumn(i);
}
try (Table X = new Table(df);) {
ColumnVector[] labels = new ColumnVector[1];
labels[0] = tmpTable.getColumn(12);
try (Table y = new Table(labels);) {
CudfColumnBatch batch = new CudfColumnBatch(X, y, null, null);
CudfColumn labelColumn = CudfColumn.from(tmpTable.getColumn(12));
//set watchList
HashMap<String, DMatrix> watches = new HashMap<>();
DMatrix dMatrix1 = new DMatrix(batch, Float.NaN, 1);
dMatrix1.setLabel(labelColumn);
watches.put("train", dMatrix1);
Booster model1 = XGBoost.train(dMatrix1, paramMap, round, watches, null, null);
List<ColumnBatch> tables = new LinkedList<>();
tables.add(batch);
DMatrix incrementalDMatrix = new DeviceQuantileDMatrix(tables.iterator(), Float.NaN, maxBin, 1);
//set watchList
HashMap<String, DMatrix> watches1 = new HashMap<>();
watches1.put("train", incrementalDMatrix);
Booster model2 = XGBoost.train(incrementalDMatrix, paramMap, round, watches1, null, null);
float[][] predict1 = model1.predict(dMatrix1);
float[][] predict2 = model2.predict(dMatrix1);
for (int i = 0; i < tmpTable.getRowCount(); i++) {
  TestCase.assertTrue(Math.abs(predict1[i][0] - predict2[i][0]) < 1e-6);
}
}
}
}
}
}

jvm-packages/xgboost4j-gpu/src/test/java/ml/dmlc/xgboost4j/gpu/java/DMatrixTest.java

@@ -0,0 +1,123 @@
/*
Copyright (c) 2021 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.gpu.java;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import junit.framework.TestCase;
import com.google.common.primitives.Floats;
import org.apache.commons.lang.ArrayUtils;
import org.junit.Test;
import ai.rapids.cudf.Table;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.DeviceQuantileDMatrix;
import ml.dmlc.xgboost4j.java.ColumnBatch;
import ml.dmlc.xgboost4j.java.XGBoostError;
/**
* Test suite for DMatrix based on GPU
*/
public class DMatrixTest {
@Test
public void testCreateFromArrayInterfaceColumns() {
Float[] labelFloats = new Float[]{2f, 4f, 6f, 8f, 10f};
Throwable ex = null;
try (
Table X = new Table.TestBuilder().column(1.f, null, 5.f, 7.f, 9.f).build();
Table y = new Table.TestBuilder().column(labelFloats).build();
Table w = new Table.TestBuilder().column(labelFloats).build();
Table margin = new Table.TestBuilder().column(labelFloats).build();) {
CudfColumnBatch cudfDataFrame = new CudfColumnBatch(X, y, w, null);
CudfColumn labelColumn = CudfColumn.from(y.getColumn(0));
CudfColumn weightColumn = CudfColumn.from(w.getColumn(0));
CudfColumn baseMarginColumn = CudfColumn.from(margin.getColumn(0));
DMatrix dMatrix = new DMatrix(cudfDataFrame, 0, 1);
dMatrix.setLabel(labelColumn);
dMatrix.setWeight(weightColumn);
dMatrix.setBaseMargin(baseMarginColumn);
float[] anchor = convertFloatTofloat(labelFloats);
float[] label = dMatrix.getLabel();
float[] weight = dMatrix.getWeight();
float[] baseMargin = dMatrix.getBaseMargin();
TestCase.assertTrue(Arrays.equals(anchor, label));
TestCase.assertTrue(Arrays.equals(anchor, weight));
TestCase.assertTrue(Arrays.equals(anchor, baseMargin));
} catch (Throwable e) {
ex = e;
e.printStackTrace();
}
TestCase.assertNull(ex);
}
@Test
public void testCreateFromColumnDataIterator() throws XGBoostError {
Float[] label1 = {25f, 21f, 22f, 20f, 24f};
Float[] weight1 = {1.3f, 2.31f, 0.32f, 3.3f, 1.34f};
Float[] baseMargin1 = {1.2f, 0.2f, 1.3f, 2.4f, 3.5f};
Float[] label2 = {9f, 5f, 4f, 10f, 12f};
Float[] weight2 = {3.0f, 1.3f, 3.2f, 0.3f, 1.34f};
Float[] baseMargin2 = {0.2f, 2.5f, 3.1f, 4.4f, 2.2f};
try (
Table X_0 = new Table.TestBuilder()
.column(1.2f, null, 5.2f, 7.2f, 9.2f)
.column(0.2f, 0.4f, 0.6f, 2.6f, 0.10f)
.build();
Table y_0 = new Table.TestBuilder().column(label1).build();
Table w_0 = new Table.TestBuilder().column(weight1).build();
Table m_0 = new Table.TestBuilder().column(baseMargin1).build();
Table X_1 = new Table.TestBuilder().column(11.2f, 11.2f, 15.2f, 17.2f, 19.2f)
.column(1.2f, 1.4f, null, 12.6f, 10.10f).build();
Table y_1 = new Table.TestBuilder().column(label2).build();
Table w_1 = new Table.TestBuilder().column(weight2).build();
Table m_1 = new Table.TestBuilder().column(baseMargin2).build();) {
List<ColumnBatch> tables = new LinkedList<>();
tables.add(new CudfColumnBatch(X_0, y_0, w_0, m_0));
tables.add(new CudfColumnBatch(X_1, y_1, w_1, m_1));
DMatrix dmat = new DeviceQuantileDMatrix(tables.iterator(), 0.0f, 8, 1);
float[] anchorLabel = convertFloatTofloat((Float[]) ArrayUtils.addAll(label1, label2));
float[] anchorWeight = convertFloatTofloat((Float[]) ArrayUtils.addAll(weight1, weight2));
float[] anchorBaseMargin = convertFloatTofloat((Float[]) ArrayUtils.addAll(baseMargin1, baseMargin2));
TestCase.assertTrue(Arrays.equals(anchorLabel, dmat.getLabel()));
TestCase.assertTrue(Arrays.equals(anchorWeight, dmat.getWeight()));
TestCase.assertTrue(Arrays.equals(anchorBaseMargin, dmat.getBaseMargin()));
}
}
private float[] convertFloatTofloat(Float[] in) {
return Floats.toArray(Arrays.asList(in));
}
}

jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Column.java

@@ -0,0 +1,40 @@
/*
Copyright (c) 2021 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.java;
/**
 * The abstract XGBoost Column, providing the cuda array interface that is used
 * to set information on a DMatrix.
 */
public abstract class Column implements AutoCloseable {
/**
 * Get the cuda array interface json string for the Column, which may represent
 * a weight, label, or base margin column.
 *
 * This API is called by
 * {@link DMatrix#setLabel(Column)}
 * {@link DMatrix#setWeight(Column)}
 * {@link DMatrix#setBaseMargin(Column)}
 */
public abstract String getArrayInterfaceJson();
@Override
public void close() throws Exception {}
}

jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/ColumnBatch.java

@@ -0,0 +1,93 @@
/*
Copyright (c) 2021 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.java;
import java.util.Iterator;
/**
 * The abstract XGBoost ColumnBatch, providing the array interface for columnar
 * data, for example a cuDF dataframe, which follows the Apache Arrow specification.
 */
public abstract class ColumnBatch implements AutoCloseable {
/**
 * Get the cuda array interface json string for the whole ColumnBatch, including
 * the required feature and label columns and the optional weight and base margin
 * columns.
 *
 * This function is called by native code during iteration and could be made
 * private; it is kept public simply to silence the linter.
 */
public final String getArrayInterfaceJson() {
StringBuilder builder = new StringBuilder();
builder.append("{");
String featureStr = this.getFeatureArrayInterface();
if (featureStr == null || featureStr.isEmpty()) {
throw new RuntimeException("Feature array interface must not be empty");
} else {
builder.append("\"features_str\":" + featureStr);
}
String labelStr = this.getLabelsArrayInterface();
if (labelStr == null || labelStr.isEmpty()) {
throw new RuntimeException("Label array interface must not be empty");
} else {
builder.append(",\"label_str\":" + labelStr);
}
String weightStr = getWeightsArrayInterface();
if (weightStr != null && !weightStr.isEmpty()) {
  builder.append(",\"weight_str\":" + weightStr);
}
String baseMarginStr = getBaseMarginsArrayInterface();
if (baseMarginStr != null && !baseMarginStr.isEmpty()) {
  builder.append(",\"basemargin_str\":" + baseMarginStr);
}
builder.append("}");
return builder.toString();
}
/**
* Get the cuda array interface of the feature columns.
* The returned value must not be null or empty
*/
public abstract String getFeatureArrayInterface();
/**
* Get the cuda array interface of the label columns.
* The returned value must not be null or empty if we're creating
* {@link DeviceQuantileDMatrix#DeviceQuantileDMatrix(Iterator, float, int, int)}
*/
public abstract String getLabelsArrayInterface();
/**
* Get the cuda array interface of the weight columns.
* The returned value can be null or empty
*/
public abstract String getWeightsArrayInterface();
/**
* Get the cuda array interface of the base margin columns.
* The returned value can be null or empty
*/
public abstract String getBaseMarginsArrayInterface();
@Override
public void close() throws Exception {}
}
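Putting the pieces together, the string returned by `getArrayInterfaceJson` above, which the native `DataIteratorProxy` parses during iteration, has roughly this shape (illustrative; each value is itself a column array interface, and the optional keys appear only when the corresponding interface is non-empty):

```json
{
  "features_str": [ ... ],
  "label_str": [ ... ],
  "weight_str": [ ... ],
  "basemargin_str": [ ... ]
}
```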

jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DMatrix.java

@@ -177,6 +177,64 @@ public class DMatrix {
     this.handle = handle;
   }
 
+  /**
+   * Create a regular DMatrix from the column array interface.
+   * @param columnBatch the XGBoost ColumnBatch providing the cuda array interface
+   *                    of the feature columns
+   * @param missing     the missing value
+   * @param nthread     number of threads
+   * @throws XGBoostError native error
+   */
+  public DMatrix(ColumnBatch columnBatch, float missing, int nthread) throws XGBoostError {
+    long[] out = new long[1];
+    String json = columnBatch.getFeatureArrayInterface();
+    if (json == null || json.isEmpty()) {
+      throw new XGBoostError("Expecting non-empty feature columns' array interface");
+    }
+    XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixCreateFromArrayInterfaceColumns(
+        json, missing, nthread, out));
+    handle = out[0];
+  }
+
+  /**
+   * Set the label of the DMatrix from the cuda array interface.
+   *
+   * @param column the XGBoost Column providing the cuda array interface
+   *               of the label column
+   * @throws XGBoostError native error
+   */
+  public void setLabel(Column column) throws XGBoostError {
+    setXGBDMatrixInfo("label", column.getArrayInterfaceJson());
+  }
+
+  /**
+   * Set the weight of the DMatrix from the cuda array interface.
+   *
+   * @param column the XGBoost Column providing the cuda array interface
+   *               of the weight column
+   * @throws XGBoostError native error
+   */
+  public void setWeight(Column column) throws XGBoostError {
+    setXGBDMatrixInfo("weight", column.getArrayInterfaceJson());
+  }
+
+  /**
+   * Set the base margin of the DMatrix from the cuda array interface.
+   *
+   * @param column the XGBoost Column providing the cuda array interface
+   *               of the base margin column
+   * @throws XGBoostError native error
+   */
+  public void setBaseMargin(Column column) throws XGBoostError {
+    setXGBDMatrixInfo("base_margin", column.getArrayInterfaceJson());
+  }
+
+  private void setXGBDMatrixInfo(String type, String json) throws XGBoostError {
+    if (json == null || json.isEmpty()) {
+      throw new XGBoostError("Empty " + type + " columns' array interface");
+    }
+    XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixSetInfoFromInterface(handle, type, json));
+  }
+
   /**
    * set label of dmatrix

jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DeviceQuantileDMatrix.java

@@ -0,0 +1,68 @@
package ml.dmlc.xgboost4j.java;
import java.util.Iterator;
/**
 * DeviceQuantileDMatrix is intended for training only.
 */
public class DeviceQuantileDMatrix extends DMatrix {
  /**
   * Create a DeviceQuantileDMatrix from an iterator based on the cuda array interface.
   * @param iter    the iterator of XGBoost ColumnBatch providing the cuda array interface
   * @param missing the missing value
   * @param maxBin  the maximum number of bins
   * @param nthread number of threads
   * @throws XGBoostError native error
   */
public DeviceQuantileDMatrix(
Iterator<ColumnBatch> iter,
float missing,
int maxBin,
int nthread) throws XGBoostError {
super(0);
long[] out = new long[1];
XGBoostJNI.checkCall(XGBoostJNI.XGDeviceQuantileDMatrixCreateFromCallback(
iter, missing, maxBin, nthread, out));
handle = out[0];
}
@Override
public void setLabel(Column column) throws XGBoostError {
throw new XGBoostError("DeviceQuantileDMatrix does not support setLabel.");
}
@Override
public void setWeight(Column column) throws XGBoostError {
throw new XGBoostError("DeviceQuantileDMatrix does not support setWeight.");
}
@Override
public void setBaseMargin(Column column) throws XGBoostError {
throw new XGBoostError("DeviceQuantileDMatrix does not support setBaseMargin.");
}
@Override
public void setLabel(float[] labels) throws XGBoostError {
throw new XGBoostError("DeviceQuantileDMatrix does not support setLabel.");
}
@Override
public void setWeight(float[] weights) throws XGBoostError {
throw new XGBoostError("DeviceQuantileDMatrix does not support setWeight.");
}
@Override
public void setBaseMargin(float[] baseMargin) throws XGBoostError {
throw new XGBoostError("DeviceQuantileDMatrix does not support setBaseMargin.");
}
@Override
public void setBaseMargin(float[][] baseMargin) throws XGBoostError {
throw new XGBoostError("DeviceQuantileDMatrix does not support setBaseMargin.");
}
@Override
public void setGroup(int[] group) throws XGBoostError {
throw new XGBoostError("DeviceQuantileDMatrix does not support setGroup.");
}
}
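Because every setter above throws, meta info for a DeviceQuantileDMatrix has to be supplied through the batches themselves. A sketch (assuming `X`, `y`, `w`, `m` are pre-loaded cuDF `Table`s for features, label, weight, and base margin, placed in the `ml.dmlc.xgboost4j.gpu.java` package as before):

```java
package ml.dmlc.xgboost4j.gpu.java;

import java.util.Collections;

import ai.rapids.cudf.Table;
import ml.dmlc.xgboost4j.java.ColumnBatch;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.DeviceQuantileDMatrix;
import ml.dmlc.xgboost4j.java.XGBoostError;

class MetaInfoSketch {
  static DMatrix build(Table X, Table y, Table w, Table m) throws XGBoostError {
    // Weight and base margin travel inside the batch; the native callbacks
    // re-apply them to the proxy DMatrix on every pass over the data.
    ColumnBatch batch = new CudfColumnBatch(X, y, w, m);
    return new DeviceQuantileDMatrix(
        Collections.singletonList(batch).iterator(), Float.NaN, 256, 1);
  }
}
```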

jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java

@@ -134,4 +134,13 @@ class XGBoostJNI {
   // This JNI function does not support the callback function for data preparation yet.
   final static native int RabitAllreduce(ByteBuffer sendrecvbuf, int count,
                                          int enum_dtype, int enum_op);
+
+  public final static native int XGDMatrixSetInfoFromInterface(
+      long handle, String field, String json);
+
+  public final static native int XGDeviceQuantileDMatrixCreateFromCallback(
+      java.util.Iterator<ColumnBatch> iter, float missing, int maxBin, int nthread, long[] out);
+
+  public final static native int XGDMatrixCreateFromArrayInterfaceColumns(
+      String featureJson, float missing, int nthread, long[] out);
 }

jvm-packages/xgboost4j/src/native/xgboost4j.cpp

@@ -19,6 +19,7 @@
 #include <xgboost/c_api.h>
 #include <xgboost/base.h>
 #include <xgboost/logging.h>
+#include <xgboost/json.h>
 #include "./xgboost4j.h"
 #include <cstring>
 #include <vector>
@@ -43,12 +44,14 @@ void setHandle(JNIEnv *jenv, jlongArray jhandle, void* handle) {
   jenv->SetLongArrayRegion(jhandle, 0, 1, &out);
 }
 
-// global JVM
-static JavaVM* global_jvm = nullptr;
+JavaVM*& GlobalJvm() {
+  static JavaVM* vm;
+  return vm;
+}
 
 // overrides JNI on load
 jint JNI_OnLoad(JavaVM *vm, void *reserved) {
-  global_jvm = vm;
+  GlobalJvm() = vm;
   return JNI_VERSION_1_6;
 }
@@ -58,9 +61,9 @@ XGB_EXTERN_C int XGBoost4jCallbackDataIterNext(
     DataHolderHandle set_function_handle) {
   jobject jiter = static_cast<jobject>(data_handle);
   JNIEnv* jenv;
-  int jni_status = global_jvm->GetEnv((void **)&jenv, JNI_VERSION_1_6);
+  int jni_status = GlobalJvm()->GetEnv((void **)&jenv, JNI_VERSION_1_6);
   if (jni_status == JNI_EDETACHED) {
-    global_jvm->AttachCurrentThread(reinterpret_cast<void **>(&jenv), nullptr);
+    GlobalJvm()->AttachCurrentThread(reinterpret_cast<void **>(&jenv), nullptr);
   } else {
     CHECK(jni_status == JNI_OK);
   }
@@ -148,13 +151,13 @@ XGB_EXTERN_C int XGBoost4jCallbackDataIterNext(
     jenv->DeleteLocalRef(iterClass);
     // only detach if it is an async call.
     if (jni_status == JNI_EDETACHED) {
-      global_jvm->DetachCurrentThread();
+      GlobalJvm()->DetachCurrentThread();
     }
     return ret_value;
   } catch(dmlc::Error const& e) {
     // only detach if it is an async call.
     if (jni_status == JNI_EDETACHED) {
-      global_jvm->DetachCurrentThread();
+      GlobalJvm()->DetachCurrentThread();
     }
     LOG(FATAL) << e.what();
     return -1;
@@ -968,3 +971,71 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitAllreduce
   return 0;
 }
namespace xgboost {
namespace jni {
XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass jcls,
jobject jiter,
jfloat jmissing,
jint jmax_bin, jint jnthread,
jlongArray jout);
} // namespace jni
} // namespace xgboost
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGDeviceQuantileDMatrixCreateFromCallback
* Signature: (Ljava/util/Iterator;FII[J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDeviceQuantileDMatrixCreateFromCallback
(JNIEnv *jenv, jclass jcls, jobject jiter, jfloat jmissing, jint jmax_bin,
jint jnthread, jlongArray jout) {
return xgboost::jni::XGDeviceQuantileDMatrixCreateFromCallbackImpl(
jenv, jcls, jiter, jmissing, jmax_bin, jnthread, jout);
}
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGDMatrixSetInfoFromInterface
* Signature: (JLjava/lang/String;Ljava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetInfoFromInterface
(JNIEnv *jenv, jclass jcls, jlong jhandle, jstring jfield, jstring jjson_columns) {
DMatrixHandle handle = (DMatrixHandle) jhandle;
const char* field = jenv->GetStringUTFChars(jfield, 0);
const char* cjson_columns = jenv->GetStringUTFChars(jjson_columns, 0);
int ret = XGDMatrixSetInfoFromInterface(handle, field, cjson_columns);
JVM_CHECK_CALL(ret);
//release
if (field) jenv->ReleaseStringUTFChars(jfield, field);
if (cjson_columns) jenv->ReleaseStringUTFChars(jjson_columns, cjson_columns);
return ret;
}
/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGDMatrixCreateFromArrayInterfaceColumns
* Signature: (Ljava/lang/String;FI[J)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromArrayInterfaceColumns
(JNIEnv *jenv, jclass jcls, jstring jjson_columns, jfloat jmissing, jint jnthread, jlongArray jout) {
DMatrixHandle result;
const char* cjson_columns = jenv->GetStringUTFChars(jjson_columns, nullptr);
xgboost::Json config{xgboost::Object{}};
auto missing = static_cast<float>(jmissing);
auto n_threads = static_cast<int32_t>(jnthread);
config["missing"] = xgboost::Number(missing);
config["nthread"] = xgboost::Integer(n_threads);
std::string config_str;
xgboost::Json::Dump(config, &config_str);
int ret = XGDMatrixCreateFromCudaColumnar(cjson_columns, config_str.c_str(),
&result);
JVM_CHECK_CALL(ret);
if (cjson_columns) {
jenv->ReleaseStringUTFChars(jjson_columns, cjson_columns);
}
setHandle(jenv, jout, result);
return ret;
}

jvm-packages/xgboost4j/src/native/xgboost4j.h

@@ -335,6 +335,30 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitVersionNumber
 JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_RabitAllreduce
   (JNIEnv *, jclass, jobject, jint, jint, jint);
 
+/*
+ * Class:     ml_dmlc_xgboost4j_java_XGBoostJNI
+ * Method:    XGDMatrixSetInfoFromInterface
+ * Signature: (JLjava/lang/String;Ljava/lang/String;)I
+ */
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixSetInfoFromInterface
+  (JNIEnv *, jclass, jlong, jstring, jstring);
+
+/*
+ * Class:     ml_dmlc_xgboost4j_java_XGBoostJNI
+ * Method:    XGDeviceQuantileDMatrixCreateFromCallback
+ * Signature: (Ljava/util/Iterator;FII[J)I
+ */
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDeviceQuantileDMatrixCreateFromCallback
+  (JNIEnv *, jclass, jobject, jfloat, jint, jint, jlongArray);
+
+/*
+ * Class:     ml_dmlc_xgboost4j_java_XGBoostJNI
+ * Method:    XGDMatrixCreateFromArrayInterfaceColumns
+ * Signature: (Ljava/lang/String;FI[J)I
+ */
+JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixCreateFromArrayInterfaceColumns
+  (JNIEnv *, jclass, jstring, jfloat, jint, jlongArray);
+
 #ifdef __cplusplus
 }
 #endif

tests/ci_build/Dockerfile.jvm_gpu_build

@@ -22,7 +22,7 @@ RUN \
 # NCCL2 (License: https://docs.nvidia.com/deeplearning/sdk/nccl-sla/index.html)
 RUN \
     export CUDA_SHORT=`echo $CUDA_VERSION_ARG | grep -o -E '[0-9]+\.[0-9]'` && \
-    export NCCL_VERSION=2.4.8-1 && \
+    export NCCL_VERSION=2.8.3-1 && \
     wget -nv -nc https://developer.download.nvidia.com/compute/machine-learning/repos/rhel7/x86_64/nvidia-machine-learning-repo-rhel7-1.0.0-1.x86_64.rpm && \
     rpm -i nvidia-machine-learning-repo-rhel7-1.0.0-1.x86_64.rpm && \
     yum -y update && \