[jvm-packages][xgboost4j-gpu] Support GPU dataframe and DeviceQuantileDMatrix (#7195)

Following classes are added to support dataframe in java binding:

- `Column` is an abstract type for a single column in tabular data.
- `ColumnBatch` is an abstract type for dataframe.

- `CuDFColumn` is an implementaiton of `Column` that consume cuDF column
- `CudfColumnBatch` is an implementation of `ColumnBatch` that consumes cuDF dataframe.

- `DeviceQuantileDMatrix` is the interface for quantized data.

The Java implementation mimics the Python interface and uses `__cuda_array_interface__` protocol for memory indexing.  One difference is on JVM package, the data batch is staged on the host as java iterators cannot be reset.

Co-authored-by: jiamingy <jm.yuan@outlook.com>
This commit is contained in:
Bobby Wang
2021-09-24 14:25:00 +08:00
committed by GitHub
parent d27a427dc5
commit 0ee11dac77
23 changed files with 1388 additions and 18 deletions

View File

@@ -12,7 +12,24 @@
<version>1.5.0-SNAPSHOT</version>
<packaging>jar</packaging>
<properties>
<cudf.version>21.08.2</cudf.version>
<cudf.classifier>cuda11</cudf.classifier>
</properties>
<dependencies>
<dependency>
<groupId>ai.rapids</groupId>
<artifactId>cudf</artifactId>
<version>${cudf.version}</version>
<classifier>${cudf.classifier}</classifier>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
<version>2.10.5.1</version>
</dependency>
<dependency>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-hdfs</artifactId>

View File

@@ -1 +0,0 @@
../../../xgboost4j/src/main/java/

View File

@@ -0,0 +1,110 @@
/*
Copyright (c) 2021 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.gpu.java;
import ai.rapids.cudf.BaseDeviceMemoryBuffer;
import ai.rapids.cudf.BufferType;
import ai.rapids.cudf.ColumnVector;
import ai.rapids.cudf.DType;
import ml.dmlc.xgboost4j.java.Column;
/**
* This class is composing of base data with Apache Arrow format from Cudf ColumnVector.
* It will be used to generate the cuda array interface.
*/
class CudfColumn extends Column {
private final long dataPtr; // gpu data buffer address
private final long shape; // row count
private final long validPtr; // gpu valid buffer address
private final int typeSize; // type size in bytes
private final String typeStr; // follow array interface spec
private final long nullCount; // null count
private String arrayInterface = null; // the cuda array interface
public static CudfColumn from(ColumnVector cv) {
BaseDeviceMemoryBuffer dataBuffer = cv.getDeviceBufferFor(BufferType.DATA);
BaseDeviceMemoryBuffer validBuffer = cv.getDeviceBufferFor(BufferType.VALIDITY);
long validPtr = 0;
if (validBuffer != null) {
validPtr = validBuffer.getAddress();
}
DType dType = cv.getType();
String typeStr = "";
if (dType == DType.FLOAT32 || dType == DType.FLOAT64 ||
dType == DType.TIMESTAMP_DAYS || dType == DType.TIMESTAMP_MICROSECONDS ||
dType == DType.TIMESTAMP_MILLISECONDS || dType == DType.TIMESTAMP_NANOSECONDS ||
dType == DType.TIMESTAMP_SECONDS) {
typeStr = "<f" + dType.getSizeInBytes();
} else if (dType == DType.BOOL8 || dType == DType.INT8 || dType == DType.INT16 ||
dType == DType.INT32 || dType == DType.INT64) {
typeStr = "<i" + dType.getSizeInBytes();
} else {
// Unsupported type.
throw new IllegalArgumentException("Unsupported data type: " + dType);
}
return new CudfColumn(dataBuffer.getAddress(), cv.getRowCount(), validPtr,
dType.getSizeInBytes(), typeStr, cv.getNullCount());
}
private CudfColumn(long dataPtr, long shape, long validPtr, int typeSize, String typeStr,
long nullCount) {
this.dataPtr = dataPtr;
this.shape = shape;
this.validPtr = validPtr;
this.typeSize = typeSize;
this.typeStr = typeStr;
this.nullCount = nullCount;
}
@Override
public String getArrayInterfaceJson() {
// There is no race-condition
if (arrayInterface == null) {
arrayInterface = CudfUtils.buildArrayInterface(this);
}
return arrayInterface;
}
public long getDataPtr() {
return dataPtr;
}
public long getShape() {
return shape;
}
public long getValidPtr() {
return validPtr;
}
public int getTypeSize() {
return typeSize;
}
public String getTypeStr() {
return typeStr;
}
public long getNullCount() {
return nullCount;
}
}

View File

@@ -0,0 +1,88 @@
/*
Copyright (c) 2021 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.gpu.java;
import java.util.stream.IntStream;
import ai.rapids.cudf.Table;
import ml.dmlc.xgboost4j.java.ColumnBatch;
/**
* Class to wrap CUDF Table to generate the cuda array interface.
*/
public class CudfColumnBatch extends ColumnBatch {
private final Table feature;
private final Table label;
private final Table weight;
private final Table baseMargin;
public CudfColumnBatch(Table feature, Table labels, Table weights, Table baseMargins) {
this.feature = feature;
this.label = labels;
this.weight = weights;
this.baseMargin = baseMargins;
}
@Override
public String getFeatureArrayInterface() {
return getArrayInterface(this.feature);
}
@Override
public String getLabelsArrayInterface() {
return getArrayInterface(this.label);
}
@Override
public String getWeightsArrayInterface() {
return getArrayInterface(this.weight);
}
@Override
public String getBaseMarginsArrayInterface() {
return getArrayInterface(this.baseMargin);
}
@Override
public void close() {
if (feature != null) feature.close();
if (label != null) label.close();
if (weight != null) weight.close();
if (baseMargin != null) baseMargin.close();
}
private String getArrayInterface(Table table) {
if (table == null || table.getNumberOfColumns() == 0) {
return "";
}
return CudfUtils.buildArrayInterface(getAsCudfColumn(table));
}
private CudfColumn[] getAsCudfColumn(Table table) {
if (table == null || table.getNumberOfColumns() == 0) {
// This will never happen.
return new CudfColumn[]{};
}
return IntStream.range(0, table.getNumberOfColumns())
.mapToObj((i) -> table.getColumn(i))
.map(CudfColumn::from)
.toArray(CudfColumn[]::new);
}
}

View File

@@ -0,0 +1,100 @@
/*
Copyright (c) 2021 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.gpu.java;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import com.fasterxml.jackson.core.JsonFactory;
import com.fasterxml.jackson.core.JsonGenerator;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ArrayNode;
import com.fasterxml.jackson.databind.node.JsonNodeFactory;
import com.fasterxml.jackson.databind.node.ObjectNode;
/**
* Cudf utilities to build cuda array interface against {@link CudfColumn}
*/
class CudfUtils {
/**
* Build the cuda array interface based on CudfColumn(s)
* @param cudfColumns the CudfColumn(s) to be built
* @return the json format of cuda array interface
*/
public static String buildArrayInterface(CudfColumn... cudfColumns) {
return new Builder().add(cudfColumns).build();
}
// Helper class to build array interface string
private static class Builder {
private JsonNodeFactory nodeFactory = new JsonNodeFactory(false);
private ArrayNode rootArrayNode = nodeFactory.arrayNode();
private Builder add(CudfColumn... columns) {
if (columns == null || columns.length <= 0) {
throw new IllegalArgumentException("At least one ColumnData is required.");
}
for (CudfColumn cd : columns) {
rootArrayNode.add(buildColumnObject(cd));
}
return this;
}
private String build() {
try {
ByteArrayOutputStream bos = new ByteArrayOutputStream();
JsonGenerator jsonGen = new JsonFactory().createGenerator(bos);
new ObjectMapper().writeTree(jsonGen, rootArrayNode);
return bos.toString();
} catch (IOException ie) {
ie.printStackTrace();
throw new RuntimeException("Failed to build array interface. Error: " + ie);
}
}
private ObjectNode buildColumnObject(CudfColumn column) {
if (column.getDataPtr() == 0) {
throw new IllegalArgumentException("Empty column data is NOT accepted!");
}
if (column.getTypeStr() == null || column.getTypeStr().isEmpty()) {
throw new IllegalArgumentException("Empty type string is NOT accepted!");
}
ObjectNode colDataObj = buildMetaObject(column.getDataPtr(), column.getShape(),
column.getTypeStr());
if (column.getValidPtr() != 0 && column.getNullCount() != 0) {
ObjectNode validObj = buildMetaObject(column.getValidPtr(), column.getShape(), "<t1");
colDataObj.set("mask", validObj);
}
return colDataObj;
}
private ObjectNode buildMetaObject(long ptr, long shape, final String typeStr) {
ObjectNode objNode = nodeFactory.objectNode();
ArrayNode shapeNode = objNode.putArray("shape");
shapeNode.add(shape);
ArrayNode dataNode = objNode.putArray("data");
dataNode.add(ptr)
.add(false);
objNode.put("typestr", typeStr)
.put("version", 1);
return objNode;
}
}
}

View File

@@ -0,0 +1 @@
../../../../../../../xgboost4j/src/main/java/ml/dmlc/xgboost4j/java

View File

@@ -1 +0,0 @@
../../xgboost4j/src/native

View File

@@ -0,0 +1,15 @@
#ifndef JVM_UTILS_H_
#define JVM_UTILS_H_
#define JVM_CHECK_CALL(__expr) \
{ \
int __errcode = (__expr); \
if (__errcode != 0) { \
return __errcode; \
} \
}
JavaVM*& GlobalJvm();
void setHandle(JNIEnv *jenv, jlongArray jhandle, void* handle);
#endif // JVM_UTILS_H_

View File

@@ -0,0 +1,25 @@
//
// Created by bobwang on 2021/9/8.
//
#ifndef XGBOOST_USE_CUDA
#include <jni.h>
#include "../../../../src/common/common.h"
#include "../../../../src/c_api/c_api_error.h"
namespace xgboost {
namespace jni {
XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass jcls,
jobject jiter,
jfloat jmissing,
jint jmax_bin, jint jnthread,
jlongArray jout) {
API_BEGIN();
common::AssertGPUSupport();
API_END();
}
} // namespace jni
} // namespace xgboost
#endif // XGBOOST_USE_CUDA

View File

@@ -0,0 +1,398 @@
#include <jni.h>
#include <thrust/system/cuda/experimental/pinned_allocator.h>
#include "../../../../src/common/device_helpers.cuh"
#include "../../../../src/data/array_interface.h"
#include "jvm_utils.h"
#include <xgboost/c_api.h>
namespace xgboost {
namespace jni {
template <typename T, typename Alloc>
T const *RawPtr(std::vector<T, Alloc> const &data) {
return data.data();
}
template <typename T, typename Alloc> T *RawPtr(std::vector<T, Alloc> &data) {
return data.data();
}
template <typename T> T const *RawPtr(dh::device_vector<T> const &data) {
return data.data().get();
}
template <typename T> T *RawPtr(dh::device_vector<T> &data) {
return data.data().get();
}
template <typename T> T CheckJvmCall(T const &v, JNIEnv *jenv) {
if (!v) {
CHECK(jenv->ExceptionOccurred());
jenv->ExceptionDescribe();
}
return v;
}
template <typename VCont>
void CopyColumnMask(xgboost::ArrayInterface const &interface,
std::vector<Json> const &columns, cudaMemcpyKind kind,
size_t c, VCont *p_mask, Json *p_out, cudaStream_t stream) {
auto &mask = *p_mask;
auto &out = *p_out;
auto size = sizeof(typename VCont::value_type) * interface.num_rows *
interface.num_cols;
mask.resize(size);
CHECK(RawPtr(mask));
CHECK(size);
CHECK(interface.valid.Data());
dh::safe_cuda(
cudaMemcpyAsync(RawPtr(mask), interface.valid.Data(), size, kind, stream));
auto const &mask_column = columns[c]["mask"];
out["mask"] = Object();
std::vector<Json> mask_data{
Json{reinterpret_cast<Integer::Int>(RawPtr(mask))},
Json{get<Boolean const>(mask_column["data"][1])}};
out["mask"]["data"] = Array(std::move(mask_data));
if (get<Array const>(mask_column["shape"]).size() == 2) {
std::vector<Json> mask_shape{
Json{get<Integer const>(mask_column["shape"][0])},
Json{get<Integer const>(mask_column["shape"][1])}};
out["mask"]["shape"] = Array(std::move(mask_shape));
} else if (get<Array const>(mask_column["shape"]).size() == 1) {
std::vector<Json> mask_shape{
Json{get<Integer const>(mask_column["shape"][0])}};
out["mask"]["shape"] = Array(std::move(mask_shape));
} else {
LOG(FATAL) << "Invalid shape of mask";
}
out["mask"]["typestr"] = String("<t1");
out["mask"]["version"] = Integer(1);
}
template <typename DCont, typename VCont>
void CopyInterface(std::vector<xgboost::ArrayInterface> &interface_arr,
std::vector<Json> const &columns, cudaMemcpyKind kind,
std::vector<DCont> *p_data, std::vector<VCont> *p_mask,
std::vector<xgboost::Json> *p_out, cudaStream_t stream) {
p_data->resize(interface_arr.size());
p_mask->resize(interface_arr.size());
p_out->resize(interface_arr.size());
for (size_t c = 0; c < interface_arr.size(); ++c) {
auto &interface = interface_arr.at(c);
size_t element_size = interface.ElementSize();
size_t size = element_size * interface.num_rows * interface.num_cols;
auto &data = (*p_data)[c];
auto &mask = (*p_mask)[c];
data.resize(size);
dh::safe_cuda(cudaMemcpyAsync(RawPtr(data), interface.data, size, kind, stream));
auto &out = (*p_out)[c];
out = Object();
std::vector<Json> j_data{
Json{Integer(reinterpret_cast<Integer::Int>(RawPtr(data)))},
Json{Boolean{false}}};
out["data"] = Array(std::move(j_data));
out["shape"] = Array(std::vector<Json>{Json(Integer(interface.num_rows)),
Json(Integer(interface.num_cols))});
if (interface.valid.Data()) {
CopyColumnMask(interface, columns, kind, c, &mask, &out, stream);
}
out["typestr"] = String("<f4");
out["version"] = Integer(1);
}
}
void CopyMetaInfo(Json *p_interface, dh::device_vector<float> *out, cudaStream_t stream) {
auto &j_interface = *p_interface;
CHECK_EQ(get<Array const>(j_interface).size(), 1);
auto object = get<Object>(get<Array>(j_interface)[0]);
ArrayInterface interface(object);
out->resize(interface.num_rows);
size_t element_size = interface.ElementSize();
size_t size = element_size * interface.num_rows;
dh::safe_cuda(cudaMemcpyAsync(RawPtr(*out), interface.data, size,
cudaMemcpyDeviceToDevice, stream));
j_interface[0]["data"][0] = reinterpret_cast<Integer::Int>(RawPtr(*out));
}
template <typename DCont, typename VCont> struct DataFrame {
std::vector<DCont> data;
std::vector<VCont> valid;
std::vector<Json> interfaces;
};
class DataIteratorProxy {
DMatrixHandle proxy_;
JNIEnv *jenv_;
int jni_status_;
jobject jiter_;
bool cache_on_host_{true}; // TODO(Bobby): Make this optional.
template <typename T>
using Alloc = thrust::system::cuda::experimental::pinned_allocator<T>;
template <typename U>
using HostVector = std::vector<U, Alloc<U>>;
// This vector is created for staging device data on host to save GPU memory.
// When space is not of concern, we can stage them on device memory directly.
std::vector<
std::unique_ptr<DataFrame<HostVector<char>, HostVector<std::uint8_t>>>>
host_columns_;
// TODO(Bobby): Use this instead of `host_columns_` if staging is not
// required.
std::vector<std::unique_ptr<DataFrame<dh::device_vector<char>,
dh::device_vector<std::uint8_t>>>>
device_columns_;
// Staging area for metainfo.
// TODO(Bobby): label_upper_bound, label_lower_bound, group.
std::vector<std::unique_ptr<dh::device_vector<float>>> labels_;
std::vector<std::unique_ptr<dh::device_vector<float>>> weights_;
std::vector<std::unique_ptr<dh::device_vector<float>>> base_margins_;
std::vector<Json> label_interfaces_;
std::vector<Json> weight_interfaces_;
std::vector<Json> margin_interfaces_;
size_t it_{0};
size_t n_batches_{0};
bool initialized_{false};
jobject last_batch_ {nullptr};
// Temp buffer on device, each `dh::device_vector` represents a column
// from cudf.
std::vector<dh::device_vector<char>> staging_data_;
std::vector<dh::device_vector<uint8_t>> staging_mask_;
cudaStream_t copy_stream_;
public:
explicit DataIteratorProxy(jobject jiter, bool cache_on_host = true)
: jiter_{jiter}, cache_on_host_{cache_on_host} {
XGProxyDMatrixCreate(&proxy_);
jni_status_ =
GlobalJvm()->GetEnv(reinterpret_cast<void **>(&jenv_), JNI_VERSION_1_6);
this->Reset();
dh::safe_cuda(cudaStreamCreateWithFlags(&copy_stream_, cudaStreamNonBlocking));
}
~DataIteratorProxy() { XGDMatrixFree(proxy_);
dh::safe_cuda(cudaStreamDestroy(copy_stream_));
}
DMatrixHandle GetDMatrixHandle() const { return proxy_; }
// Helper function for staging meta info.
void StageMetaInfo(Json json_interface) {
CHECK(!IsA<Null>(json_interface));
auto json_map = get<Object const>(json_interface);
if (json_map.find("label_str") == json_map.cend()) {
LOG(FATAL) << "Must have a label field.";
}
Json label = json_interface["label_str"];
CHECK(!IsA<Null>(label));
labels_.emplace_back(new dh::device_vector<float>);
CopyMetaInfo(&label, labels_.back().get(), copy_stream_);
label_interfaces_.emplace_back(label);
std::string str;
Json::Dump(label, &str);
XGDMatrixSetInfoFromInterface(proxy_, "label", str.c_str());
if (json_map.find("weight_str") != json_map.cend()) {
Json weight = json_interface["weight_str"];
CHECK(!IsA<Null>(weight));
weights_.emplace_back(new dh::device_vector<float>);
CopyMetaInfo(&weight, weights_.back().get(), copy_stream_);
weight_interfaces_.emplace_back(weight);
Json::Dump(weight, &str);
XGDMatrixSetInfoFromInterface(proxy_, "weight", str.c_str());
}
if (json_map.find("basemargin_str") != json_map.cend()) {
Json basemargin = json_interface["basemargin_str"];
base_margins_.emplace_back(new dh::device_vector<float>);
CopyMetaInfo(&basemargin, base_margins_.back().get(), copy_stream_);
margin_interfaces_.emplace_back(basemargin);
Json::Dump(basemargin, &str);
XGDMatrixSetInfoFromInterface(proxy_, "base_margin", str.c_str());
}
}
void CloseJvmBatch() {
if (last_batch_) {
jclass batch_class = CheckJvmCall(jenv_->GetObjectClass(last_batch_), jenv_);
jmethodID closeMethod = CheckJvmCall(jenv_->GetMethodID(batch_class, "close", "()V"), jenv_);
jenv_->CallVoidMethod(last_batch_, closeMethod);
last_batch_ = nullptr;
}
}
void Reset() {
it_ = 0;
this->CloseJvmBatch();
}
int32_t PullIterFromJVM() {
jclass iterClass = jenv_->FindClass("java/util/Iterator");
this->CloseJvmBatch();
jmethodID has_next =
CheckJvmCall(jenv_->GetMethodID(iterClass, "hasNext", "()Z"), jenv_);
jmethodID next = CheckJvmCall(
jenv_->GetMethodID(iterClass, "next", "()Ljava/lang/Object;"), jenv_);
if (jenv_->CallBooleanMethod(jiter_, has_next)) {
// batch should be ColumnBatch from jvm
jobject batch = CheckJvmCall(jenv_->CallObjectMethod(jiter_, next), jenv_);
jclass batch_class = CheckJvmCall(jenv_->GetObjectClass(batch), jenv_);
jmethodID getArrayInterfaceJson = CheckJvmCall(jenv_->GetMethodID(
batch_class, "getArrayInterfaceJson", "()Ljava/lang/String;"), jenv_);
auto jinterface =
static_cast<jstring>(jenv_->CallObjectMethod(batch, getArrayInterfaceJson));
CheckJvmCall(jinterface, jenv_);
char const *c_interface_str =
CheckJvmCall(jenv_->GetStringUTFChars(jinterface, nullptr), jenv_);
StageData(c_interface_str);
jenv_->ReleaseStringUTFChars(jinterface, c_interface_str);
last_batch_ = batch;
return 1;
} else {
return 0;
}
}
void StageData(std::string interface_str) {
++n_batches_;
// DataFrame
using T = decltype(host_columns_)::value_type::element_type;
host_columns_.emplace_back(std::unique_ptr<T>(new T));
// Stage the meta info.
auto json_interface =
Json::Load({interface_str.c_str(), interface_str.size()});
CHECK(!IsA<Null>(json_interface));
StageMetaInfo(json_interface);
Json features = json_interface["features_str"];
auto json_columns = get<Array const>(features);
std::vector<ArrayInterface> interfaces;
// Stage the data
for (auto &json_col : json_columns) {
auto column = ArrayInterface(get<Object const>(json_col));
interfaces.emplace_back(column);
}
Json::Dump(features, &interface_str);
CopyInterface(interfaces, json_columns, cudaMemcpyDeviceToHost,
&host_columns_.back()->data, &host_columns_.back()->valid,
&host_columns_.back()->interfaces, copy_stream_);
XGProxyDMatrixSetDataCudaColumnar(proxy_, interface_str.c_str());
it_++;
}
int NextFirstLoop() {
try {
dh::safe_cuda(cudaStreamSynchronize(copy_stream_));
if (this->PullIterFromJVM()) {
return 1;
} else {
initialized_ = true;
return 0;
}
} catch (dmlc::Error const &e) {
if (jni_status_ == JNI_EDETACHED) {
GlobalJvm()->DetachCurrentThread();
}
LOG(FATAL) << e.what();
}
LOG(FATAL) << "Unreachable";
return 1;
}
int NextSecondLoop() {
std::string str;
// Meta
auto const &label = this->label_interfaces_.at(it_);
Json::Dump(label, &str);
XGDMatrixSetInfoFromInterface(proxy_, "label", str.c_str());
if (n_batches_ == this->weight_interfaces_.size()) {
auto const &weight = this->weight_interfaces_.at(it_);
Json::Dump(weight, &str);
XGDMatrixSetInfoFromInterface(proxy_, "weight", str.c_str());
}
if (n_batches_ == this->margin_interfaces_.size()) {
auto const &base_margin = this->margin_interfaces_.at(it_);
Json::Dump(base_margin, &str);
XGDMatrixSetInfoFromInterface(proxy_, "base_margin", str.c_str());
}
// Data
auto const &json_interface = host_columns_.at(it_)->interfaces;
std::vector<ArrayInterface> in;
for (auto interface : json_interface) {
auto column = ArrayInterface(get<Object const>(interface));
in.emplace_back(column);
}
std::vector<Json> out;
CopyInterface(in, json_interface, cudaMemcpyHostToDevice, &staging_data_,
&staging_mask_, &out, nullptr);
Json temp{Array(std::move(out))};
std::string interface_str;
Json::Dump(temp, &interface_str);
XGProxyDMatrixSetDataCudaColumnar(proxy_, interface_str.c_str());
it_++;
return 1;
}
int Next() {
if (!initialized_) {
return NextFirstLoop();
} else {
if (it_ == n_batches_) {
return 0;
}
return NextSecondLoop();
}
};
};
namespace {
void Reset(DataIterHandle self) {
static_cast<xgboost::jni::DataIteratorProxy *>(self)->Reset();
}
int Next(DataIterHandle self) {
return static_cast<xgboost::jni::DataIteratorProxy *>(self)->Next();
}
} // anonymous namespace
XGB_DLL jint XGDeviceQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass jcls,
jobject jiter,
jfloat jmissing,
jint jmax_bin, jint jnthread,
jlongArray jout) {
xgboost::jni::DataIteratorProxy proxy(jiter);
DMatrixHandle result;
auto ret = XGDeviceQuantileDMatrixCreateFromCallback(
&proxy, proxy.GetDMatrixHandle(), Reset, Next, jmissing, jnthread,
jmax_bin, &result);
setHandle(jenv, jout, result);
return ret;
}
} // namespace jni
} // namespace xgboost

View File

@@ -1 +0,0 @@
../../xgboost4j/src/test

View File

@@ -0,0 +1,127 @@
/*
Copyright (c) 2021 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.gpu.java;
import java.io.File;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import junit.framework.TestCase;
import org.junit.Test;
import ai.rapids.cudf.DType;
import ai.rapids.cudf.Schema;
import ai.rapids.cudf.Table;
import ai.rapids.cudf.ColumnVector;
import ai.rapids.cudf.CSVOptions;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.ColumnBatch;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.DeviceQuantileDMatrix;
import ml.dmlc.xgboost4j.java.XGBoost;
import ml.dmlc.xgboost4j.java.XGBoostError;
/**
* Tests the BoosterTest trained by DMatrix
* @throws XGBoostError
*/
public class BoosterTest {
@Test
public void testBooster() throws XGBoostError {
String trainingDataPath = "../../demo/data/veterans_lung_cancer.csv";
Schema schema = Schema.builder()
.column(DType.FLOAT32, "A")
.column(DType.FLOAT32, "B")
.column(DType.FLOAT32, "C")
.column(DType.FLOAT32, "D")
.column(DType.FLOAT32, "E")
.column(DType.FLOAT32, "F")
.column(DType.FLOAT32, "G")
.column(DType.FLOAT32, "H")
.column(DType.FLOAT32, "I")
.column(DType.FLOAT32, "J")
.column(DType.FLOAT32, "K")
.column(DType.FLOAT32, "L")
.column(DType.FLOAT32, "label")
.build();
CSVOptions opts = CSVOptions.builder()
.hasHeader().build();
int maxBin = 16;
int round = 100;
//set params
Map<String, Object> paramMap = new HashMap<String, Object>() {
{
put("max_depth", 2);
put("objective", "binary:logistic");
put("num_round", round);
put("num_workers", 1);
put("tree_method", "gpu_hist");
put("predictor", "gpu_predictor");
put("max_bin", maxBin);
}
};
try (Table tmpTable = Table.readCSV(schema, opts, new File(trainingDataPath))) {
ColumnVector[] df = new ColumnVector[12];
for (int i = 0; i < 12; ++i) {
df[i] = tmpTable.getColumn(i);
}
try (Table X = new Table(df);) {
ColumnVector[] labels = new ColumnVector[1];
labels[0] = tmpTable.getColumn(12);
try (Table y = new Table(labels);) {
CudfColumnBatch batch = new CudfColumnBatch(X, y, null, null);
CudfColumn labelColumn = CudfColumn.from(tmpTable.getColumn(12));
//set watchList
HashMap<String, DMatrix> watches = new HashMap<>();
DMatrix dMatrix1 = new DMatrix(batch, Float.NaN, 1);
dMatrix1.setLabel(labelColumn);
watches.put("train", dMatrix1);
Booster model1 = XGBoost.train(dMatrix1, paramMap, round, watches, null, null);
List<ColumnBatch> tables = new LinkedList<>();
tables.add(batch);
DMatrix incrementalDMatrix = new DeviceQuantileDMatrix(tables.iterator(), Float.NaN, maxBin, 1);
//set watchList
HashMap<String, DMatrix> watches1 = new HashMap<>();
watches1.put("train", incrementalDMatrix);
Booster model2 = XGBoost.train(incrementalDMatrix, paramMap, round, watches1, null, null);
float[][] predicat1 = model1.predict(dMatrix1);
float[][] predicat2 = model2.predict(dMatrix1);
for (int i = 0; i < tmpTable.getRowCount(); i++) {
TestCase.assertTrue(predicat1[i][0] - predicat2[i][0] < 1e-6);
}
}
}
}
}
}

View File

@@ -0,0 +1,123 @@
/*
Copyright (c) 2021 by Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.gpu.java;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import junit.framework.TestCase;
import com.google.common.primitives.Floats;
import org.apache.commons.lang.ArrayUtils;
import org.junit.Test;
import ai.rapids.cudf.Table;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.DeviceQuantileDMatrix;
import ml.dmlc.xgboost4j.java.ColumnBatch;
import ml.dmlc.xgboost4j.java.XGBoostError;
/**
* Test suite for DMatrix based on GPU
*/
public class DMatrixTest {
@Test
public void testCreateFromArrayInterfaceColumns() {
Float[] labelFloats = new Float[]{2f, 4f, 6f, 8f, 10f};
Throwable ex = null;
try (
Table X = new Table.TestBuilder().column(1.f, null, 5.f, 7.f, 9.f).build();
Table y = new Table.TestBuilder().column(labelFloats).build();
Table w = new Table.TestBuilder().column(labelFloats).build();
Table margin = new Table.TestBuilder().column(labelFloats).build();) {
CudfColumnBatch cudfDataFrame = new CudfColumnBatch(X, y, w, null);
CudfColumn labelColumn = CudfColumn.from(y.getColumn(0));
CudfColumn weightColumn = CudfColumn.from(w.getColumn(0));
CudfColumn baseMarginColumn = CudfColumn.from(margin.getColumn(0));
DMatrix dMatrix = new DMatrix(cudfDataFrame, 0, 1);
dMatrix.setLabel(labelColumn);
dMatrix.setWeight(weightColumn);
dMatrix.setBaseMargin(baseMarginColumn);
float[] anchor = convertFloatTofloat(labelFloats);
float[] label = dMatrix.getLabel();
float[] weight = dMatrix.getWeight();
float[] baseMargin = dMatrix.getBaseMargin();
TestCase.assertTrue(Arrays.equals(anchor, label));
TestCase.assertTrue(Arrays.equals(anchor, weight));
TestCase.assertTrue(Arrays.equals(anchor, baseMargin));
} catch (Throwable e) {
ex = e;
e.printStackTrace();
}
TestCase.assertNull(ex);
}
@Test
public void testCreateFromColumnDataIterator() throws XGBoostError {
Float[] label1 = {25f, 21f, 22f, 20f, 24f};
Float[] weight1 = {1.3f, 2.31f, 0.32f, 3.3f, 1.34f};
Float[] baseMargin1 = {1.2f, 0.2f, 1.3f, 2.4f, 3.5f};
Float[] label2 = {9f, 5f, 4f, 10f, 12f};
Float[] weight2 = {3.0f, 1.3f, 3.2f, 0.3f, 1.34f};
Float[] baseMargin2 = {0.2f, 2.5f, 3.1f, 4.4f, 2.2f};
try (
Table X_0 = new Table.TestBuilder()
.column(1.2f, null, 5.2f, 7.2f, 9.2f)
.column(0.2f, 0.4f, 0.6f, 2.6f, 0.10f)
.build();
Table y_0 = new Table.TestBuilder().column(label1).build();
Table w_0 = new Table.TestBuilder().column(weight1).build();
Table m_0 = new Table.TestBuilder().column(baseMargin1).build();
Table X_1 = new Table.TestBuilder().column(11.2f, 11.2f, 15.2f, 17.2f, 19.2f)
.column(1.2f, 1.4f, null, 12.6f, 10.10f).build();
Table y_1 = new Table.TestBuilder().column(label2).build();
Table w_1 = new Table.TestBuilder().column(weight2).build();
Table m_1 = new Table.TestBuilder().column(baseMargin2).build();) {
List<ColumnBatch> tables = new LinkedList<>();
tables.add(new CudfColumnBatch(X_0, y_0, w_0, m_0));
tables.add(new CudfColumnBatch(X_1, y_1, w_1, m_1));
DMatrix dmat = new DeviceQuantileDMatrix(tables.iterator(), 0.0f, 8, 1);
float[] anchorLabel = convertFloatTofloat((Float[]) ArrayUtils.addAll(label1, label2));
float[] anchorWeight = convertFloatTofloat((Float[]) ArrayUtils.addAll(weight1, weight2));
float[] anchorBaseMargin = convertFloatTofloat((Float[]) ArrayUtils.addAll(baseMargin1, baseMargin2));
TestCase.assertTrue(Arrays.equals(anchorLabel, dmat.getLabel()));
TestCase.assertTrue(Arrays.equals(anchorWeight, dmat.getWeight()));
TestCase.assertTrue(Arrays.equals(anchorBaseMargin, dmat.getBaseMargin()));
}
}
private float[] convertFloatTofloat(Float[] in) {
return Floats.toArray(Arrays.asList(in));
}
}