[jvm-packages] remove xgboost4j-gpu and rework cudf column (#10630)
This commit is contained in:
@@ -96,6 +96,8 @@
|
||||
<argument>create_jni.py</argument>
|
||||
<argument>--log-capi-invocation</argument>
|
||||
<argument>${log.capi.invocation}</argument>
|
||||
<argument>--use-cuda</argument>
|
||||
<argument>${use.cuda}</argument>
|
||||
</arguments>
|
||||
<workingDirectory>${user.dir}</workingDirectory>
|
||||
<skip>${skip.native.build}</skip>
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2021 by Contributors
|
||||
Copyright (c) 2021-2024 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@@ -17,24 +17,17 @@
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
/**
|
||||
* The abstracted XGBoost Column to get the cuda array interface which is used to
|
||||
* set the information for DMatrix.
|
||||
*
|
||||
* This Column abstraction provides an array interface JSON string, which is
|
||||
* used to reconstruct columnar data within the XGBoost library.
|
||||
*/
|
||||
public abstract class Column implements AutoCloseable {
|
||||
|
||||
/**
|
||||
* Get the cuda array interface json string for the Column which can be representing
|
||||
* weight, label, base margin column.
|
||||
*
|
||||
* This API will be called by
|
||||
* {@link DMatrix#setLabel(Column)}
|
||||
* {@link DMatrix#setWeight(Column)}
|
||||
* {@link DMatrix#setBaseMargin(Column)}
|
||||
* Return array interface json string for this Column
|
||||
*/
|
||||
public abstract String getArrayInterfaceJson();
|
||||
public abstract String toJson();
|
||||
|
||||
@Override
|
||||
public void close() throws Exception {}
|
||||
|
||||
public void close() throws Exception {
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
Copyright (c) 2021 by Contributors
|
||||
Copyright (c) 2021-2024 by Contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@@ -16,78 +16,12 @@
|
||||
|
||||
package ml.dmlc.xgboost4j.java;
|
||||
|
||||
import java.util.Iterator;
|
||||
|
||||
/**
|
||||
* The abstracted XGBoost ColumnBatch to get array interface from columnar data format.
|
||||
* For example, the cuDF dataframe which employs apache arrow specification.
|
||||
* This class wraps multiple Column and provides the array interface json
|
||||
* for all columns.
|
||||
*/
|
||||
public abstract class ColumnBatch implements AutoCloseable {
|
||||
/**
|
||||
* Get the cuda array interface json string for the whole ColumnBatch including
|
||||
* the must-have feature, label columns and the optional weight, base margin columns.
|
||||
*
|
||||
* This function is be called by native code during iteration and can be made as private
|
||||
* method. We keep it as public simply to silent the linter.
|
||||
*/
|
||||
public final String getArrayInterfaceJson() {
|
||||
|
||||
StringBuilder builder = new StringBuilder();
|
||||
builder.append("{");
|
||||
String featureStr = this.getFeatureArrayInterface();
|
||||
if (featureStr == null || featureStr.isEmpty()) {
|
||||
throw new RuntimeException("Feature array interface must not be empty");
|
||||
} else {
|
||||
builder.append("\"features_str\":" + featureStr);
|
||||
}
|
||||
|
||||
String labelStr = this.getLabelsArrayInterface();
|
||||
if (labelStr == null || labelStr.isEmpty()) {
|
||||
throw new RuntimeException("Label array interface must not be empty");
|
||||
} else {
|
||||
builder.append(",\"label_str\":" + labelStr);
|
||||
}
|
||||
|
||||
String weightStr = getWeightsArrayInterface();
|
||||
if (weightStr != null && ! weightStr.isEmpty()) {
|
||||
builder.append(",\"weight_str\":" + weightStr);
|
||||
}
|
||||
|
||||
String baseMarginStr = getBaseMarginsArrayInterface();
|
||||
if (baseMarginStr != null && ! baseMarginStr.isEmpty()) {
|
||||
builder.append(",\"basemargin_str\":" + baseMarginStr);
|
||||
}
|
||||
|
||||
builder.append("}");
|
||||
return builder.toString();
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the cuda array interface of the feature columns.
|
||||
* The returned value must not be null or empty
|
||||
*/
|
||||
public abstract String getFeatureArrayInterface();
|
||||
|
||||
/**
|
||||
* Get the cuda array interface of the label columns.
|
||||
* The returned value must not be null or empty if we're creating
|
||||
* {@link QuantileDMatrix#QuantileDMatrix(Iterator, float, int, int)}
|
||||
*/
|
||||
public abstract String getLabelsArrayInterface();
|
||||
|
||||
/**
|
||||
* Get the cuda array interface of the weight columns.
|
||||
* The returned value can be null or empty
|
||||
*/
|
||||
public abstract String getWeightsArrayInterface();
|
||||
|
||||
/**
|
||||
* Get the cuda array interface of the base margin columns.
|
||||
* The returned value can be null or empty
|
||||
*/
|
||||
public abstract String getBaseMarginsArrayInterface();
|
||||
|
||||
@Override
|
||||
public void close() throws Exception {}
|
||||
public abstract class ColumnBatch extends Column {
|
||||
|
||||
/** Get features cuda array interface json string */
|
||||
public abstract String toFeaturesJson();
|
||||
}
|
||||
|
||||
@@ -195,7 +195,7 @@ public class DMatrix {
|
||||
*/
|
||||
public DMatrix(ColumnBatch columnBatch, float missing, int nthread) throws XGBoostError {
|
||||
long[] out = new long[1];
|
||||
String json = columnBatch.getFeatureArrayInterface();
|
||||
String json = columnBatch.toFeaturesJson();
|
||||
if (json == null || json.isEmpty()) {
|
||||
throw new XGBoostError("Expecting non-empty feature columns' array interface");
|
||||
}
|
||||
@@ -228,7 +228,7 @@ public class DMatrix {
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public void setQueryId(Column column) throws XGBoostError {
|
||||
setXGBDMatrixInfo("qid", column.getArrayInterfaceJson());
|
||||
setXGBDMatrixInfo("qid", column.toJson());
|
||||
}
|
||||
|
||||
private void setXGBDMatrixInfo(String type, String json) throws XGBoostError {
|
||||
@@ -362,7 +362,7 @@ public class DMatrix {
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public void setLabel(Column column) throws XGBoostError {
|
||||
setXGBDMatrixInfo("label", column.getArrayInterfaceJson());
|
||||
setXGBDMatrixInfo("label", column.toJson());
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -393,7 +393,7 @@ public class DMatrix {
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public void setWeight(Column column) throws XGBoostError {
|
||||
setXGBDMatrixInfo("weight", column.getArrayInterfaceJson());
|
||||
setXGBDMatrixInfo("weight", column.toJson());
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -421,7 +421,7 @@ public class DMatrix {
|
||||
* @throws XGBoostError native error
|
||||
*/
|
||||
public void setBaseMargin(Column column) throws XGBoostError {
|
||||
setXGBDMatrixInfo("base_margin", column.getArrayInterfaceJson());
|
||||
setXGBDMatrixInfo("base_margin", column.toJson());
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
15
jvm-packages/xgboost4j/src/native/jvm_utils.h
Normal file
15
jvm-packages/xgboost4j/src/native/jvm_utils.h
Normal file
@@ -0,0 +1,15 @@
|
||||
#ifndef JVM_UTILS_H_
|
||||
#define JVM_UTILS_H_
|
||||
|
||||
#define JVM_CHECK_CALL(__expr) \
|
||||
{ \
|
||||
int __errcode = (__expr); \
|
||||
if (__errcode != 0) { \
|
||||
return __errcode; \
|
||||
} \
|
||||
}
|
||||
|
||||
JavaVM*& GlobalJvm();
|
||||
void setHandle(JNIEnv *jenv, jlongArray jhandle, void* handle);
|
||||
|
||||
#endif // JVM_UTILS_H_
|
||||
32
jvm-packages/xgboost4j/src/native/xgboost4j-gpu.cpp
Normal file
32
jvm-packages/xgboost4j/src/native/xgboost4j-gpu.cpp
Normal file
@@ -0,0 +1,32 @@
|
||||
//
|
||||
// Created by bobwang on 2021/9/8.
|
||||
//
|
||||
|
||||
#ifndef XGBOOST_USE_CUDA
|
||||
|
||||
#include <jni.h>
|
||||
|
||||
#include "../../../../src/common/common.h"
|
||||
#include "../../../../src/c_api/c_api_error.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace jni {
|
||||
XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass jcls,
|
||||
jobject jiter,
|
||||
jfloat jmissing,
|
||||
jint jmax_bin, jint jnthread,
|
||||
jlongArray jout) {
|
||||
API_BEGIN();
|
||||
common::AssertGPUSupport();
|
||||
API_END();
|
||||
}
|
||||
XGB_DLL int XGQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass jcls,
|
||||
jobject jdata_iter, jobject jref_iter,
|
||||
char const *config, jlongArray jout) {
|
||||
API_BEGIN();
|
||||
common::AssertGPUSupport();
|
||||
API_END();
|
||||
}
|
||||
} // namespace jni
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_USE_CUDA
|
||||
431
jvm-packages/xgboost4j/src/native/xgboost4j-gpu.cu
Normal file
431
jvm-packages/xgboost4j/src/native/xgboost4j-gpu.cu
Normal file
@@ -0,0 +1,431 @@
|
||||
#include <jni.h>
|
||||
|
||||
#include "../../../../src/common/device_helpers.cuh"
|
||||
#include "../../../../src/common/cuda_pinned_allocator.h"
|
||||
#include "../../../../src/data/array_interface.h"
|
||||
#include "jvm_utils.h"
|
||||
#include <xgboost/c_api.h>
|
||||
|
||||
namespace xgboost {
|
||||
namespace jni {
|
||||
|
||||
template <typename T, typename Alloc>
|
||||
T const *RawPtr(std::vector<T, Alloc> const &data) {
|
||||
return data.data();
|
||||
}
|
||||
|
||||
template <typename T, typename Alloc> T *RawPtr(std::vector<T, Alloc> &data) {
|
||||
return data.data();
|
||||
}
|
||||
|
||||
template <typename T> T const *RawPtr(dh::device_vector<T> const &data) {
|
||||
return data.data().get();
|
||||
}
|
||||
|
||||
template <typename T> T *RawPtr(dh::device_vector<T> &data) {
|
||||
return data.data().get();
|
||||
}
|
||||
|
||||
template <typename T> T CheckJvmCall(T const &v, JNIEnv *jenv) {
|
||||
if (!v) {
|
||||
CHECK(jenv->ExceptionOccurred());
|
||||
jenv->ExceptionDescribe();
|
||||
}
|
||||
return v;
|
||||
}
|
||||
|
||||
template <typename VCont>
|
||||
void CopyColumnMask(xgboost::ArrayInterface<1> const &interface,
|
||||
std::vector<Json> const &columns, cudaMemcpyKind kind,
|
||||
size_t c, VCont *p_mask, Json *p_out, cudaStream_t stream) {
|
||||
auto &mask = *p_mask;
|
||||
auto &out = *p_out;
|
||||
auto size = sizeof(typename VCont::value_type) * interface.n;
|
||||
mask.resize(size);
|
||||
CHECK(RawPtr(mask));
|
||||
CHECK(size);
|
||||
CHECK(interface.valid.Data());
|
||||
dh::safe_cuda(
|
||||
cudaMemcpyAsync(RawPtr(mask), interface.valid.Data(), size, kind, stream));
|
||||
auto const &mask_column = columns[c]["mask"];
|
||||
out["mask"] = Object();
|
||||
std::vector<Json> mask_data{
|
||||
Json{reinterpret_cast<Integer::Int>(RawPtr(mask))},
|
||||
Json{get<Boolean const>(mask_column["data"][1])}};
|
||||
out["mask"]["data"] = Array(std::move(mask_data));
|
||||
if (get<Array const>(mask_column["shape"]).size() == 2) {
|
||||
std::vector<Json> mask_shape{
|
||||
Json{get<Integer const>(mask_column["shape"][0])},
|
||||
Json{get<Integer const>(mask_column["shape"][1])}};
|
||||
out["mask"]["shape"] = Array(std::move(mask_shape));
|
||||
} else if (get<Array const>(mask_column["shape"]).size() == 1) {
|
||||
std::vector<Json> mask_shape{
|
||||
Json{get<Integer const>(mask_column["shape"][0])}};
|
||||
out["mask"]["shape"] = Array(std::move(mask_shape));
|
||||
} else {
|
||||
LOG(FATAL) << "Invalid shape of mask";
|
||||
}
|
||||
out["mask"]["typestr"] = String("<t1");
|
||||
out["mask"]["version"] = Integer(3);
|
||||
}
|
||||
|
||||
template <typename DCont, typename VCont>
|
||||
void CopyInterface(std::vector<xgboost::ArrayInterface<1>> &interface_arr,
|
||||
std::vector<Json> const &columns, cudaMemcpyKind kind,
|
||||
std::vector<DCont> *p_data, std::vector<VCont> *p_mask,
|
||||
std::vector<xgboost::Json> *p_out, cudaStream_t stream) {
|
||||
p_data->resize(interface_arr.size());
|
||||
p_mask->resize(interface_arr.size());
|
||||
p_out->resize(interface_arr.size());
|
||||
for (size_t c = 0; c < interface_arr.size(); ++c) {
|
||||
auto &interface = interface_arr.at(c);
|
||||
size_t element_size = interface.ElementSize();
|
||||
size_t size = element_size * interface.n;
|
||||
|
||||
auto &data = (*p_data)[c];
|
||||
auto &mask = (*p_mask)[c];
|
||||
data.resize(size);
|
||||
dh::safe_cuda(cudaMemcpyAsync(RawPtr(data), interface.data, size, kind, stream));
|
||||
|
||||
auto &out = (*p_out)[c];
|
||||
out = Object();
|
||||
std::vector<Json> j_data{
|
||||
Json{Integer(reinterpret_cast<Integer::Int>(RawPtr(data)))},
|
||||
Json{Boolean{false}}};
|
||||
|
||||
out["data"] = Array(std::move(j_data));
|
||||
out["shape"] = Array(std::vector<Json>{Json(Integer(interface.Shape(0)))});
|
||||
|
||||
if (interface.valid.Data()) {
|
||||
CopyColumnMask(interface, columns, kind, c, &mask, &out, stream);
|
||||
}
|
||||
out["typestr"] = String("<f4");
|
||||
out["version"] = Integer(3);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void CopyMetaInfo(Json *p_interface, dh::device_vector<T> *out, cudaStream_t stream) {
|
||||
auto &j_interface = *p_interface;
|
||||
CHECK_EQ(get<Array const>(j_interface).size(), 1);
|
||||
auto object = get<Object>(get<Array>(j_interface)[0]);
|
||||
ArrayInterface<1> interface(object);
|
||||
out->resize(interface.Shape(0));
|
||||
size_t element_size = interface.ElementSize();
|
||||
size_t size = element_size * interface.n;
|
||||
dh::safe_cuda(cudaMemcpyAsync(RawPtr(*out), interface.data, size,
|
||||
cudaMemcpyDeviceToDevice, stream));
|
||||
j_interface[0]["data"][0] = reinterpret_cast<Integer::Int>(RawPtr(*out));
|
||||
}
|
||||
|
||||
template <typename DCont, typename VCont> struct DataFrame {
|
||||
std::vector<DCont> data;
|
||||
std::vector<VCont> valid;
|
||||
std::vector<Json> interfaces;
|
||||
};
|
||||
|
||||
class DataIteratorProxy {
|
||||
DMatrixHandle proxy_;
|
||||
JNIEnv *jenv_;
|
||||
int jni_status_;
|
||||
jobject jiter_;
|
||||
bool cache_on_host_{true}; // TODO(Bobby): Make this optional.
|
||||
|
||||
template <typename T>
|
||||
using Alloc = xgboost::common::cuda::pinned_allocator<T>;
|
||||
template <typename U>
|
||||
using HostVector = std::vector<U, Alloc<U>>;
|
||||
|
||||
// This vector is created for staging device data on host to save GPU memory.
|
||||
// When space is not of concern, we can stage them on device memory directly.
|
||||
std::vector<
|
||||
std::unique_ptr<DataFrame<HostVector<char>, HostVector<std::uint8_t>>>>
|
||||
host_columns_;
|
||||
// TODO(Bobby): Use this instead of `host_columns_` if staging is not
|
||||
// required.
|
||||
std::vector<std::unique_ptr<DataFrame<dh::device_vector<char>,
|
||||
dh::device_vector<std::uint8_t>>>>
|
||||
device_columns_;
|
||||
|
||||
// Staging area for metainfo.
|
||||
// TODO(Bobby): label_upper_bound, label_lower_bound, group.
|
||||
std::vector<std::unique_ptr<dh::device_vector<float>>> labels_;
|
||||
std::vector<std::unique_ptr<dh::device_vector<float>>> weights_;
|
||||
std::vector<std::unique_ptr<dh::device_vector<float>>> base_margins_;
|
||||
std::vector<std::unique_ptr<dh::device_vector<int>>> qids_;
|
||||
std::vector<Json> label_interfaces_;
|
||||
std::vector<Json> weight_interfaces_;
|
||||
std::vector<Json> margin_interfaces_;
|
||||
std::vector<Json> qid_interfaces_;
|
||||
|
||||
size_t it_{0};
|
||||
size_t n_batches_{0};
|
||||
bool initialized_{false};
|
||||
jobject last_batch_ {nullptr};
|
||||
|
||||
// Temp buffer on device, each `dh::device_vector` represents a column
|
||||
// from cudf.
|
||||
std::vector<dh::device_vector<char>> staging_data_;
|
||||
std::vector<dh::device_vector<uint8_t>> staging_mask_;
|
||||
|
||||
cudaStream_t copy_stream_;
|
||||
|
||||
public:
|
||||
explicit DataIteratorProxy(jobject jiter, bool cache_on_host = true)
|
||||
: jiter_{jiter}, cache_on_host_{cache_on_host} {
|
||||
XGProxyDMatrixCreate(&proxy_);
|
||||
jni_status_ =
|
||||
GlobalJvm()->GetEnv(reinterpret_cast<void **>(&jenv_), JNI_VERSION_1_6);
|
||||
this->Reset();
|
||||
dh::safe_cuda(cudaStreamCreateWithFlags(©_stream_, cudaStreamNonBlocking));
|
||||
}
|
||||
~DataIteratorProxy() { XGDMatrixFree(proxy_);
|
||||
dh::safe_cuda(cudaStreamDestroy(copy_stream_));
|
||||
}
|
||||
|
||||
DMatrixHandle GetDMatrixHandle() const { return proxy_; }
|
||||
|
||||
// Helper function for staging meta info.
|
||||
void StageMetaInfo(Json json_interface) {
|
||||
CHECK(!IsA<Null>(json_interface));
|
||||
auto json_map = get<Object const>(json_interface);
|
||||
if (json_map.find("label") == json_map.cend()) {
|
||||
LOG(FATAL) << "Must have a label field.";
|
||||
}
|
||||
|
||||
Json label = json_interface["label"];
|
||||
CHECK(!IsA<Null>(label));
|
||||
labels_.emplace_back(new dh::device_vector<float>);
|
||||
CopyMetaInfo(&label, labels_.back().get(), copy_stream_);
|
||||
label_interfaces_.emplace_back(label);
|
||||
|
||||
std::string str;
|
||||
Json::Dump(label, &str);
|
||||
XGDMatrixSetInfoFromInterface(proxy_, "label", str.c_str());
|
||||
|
||||
if (json_map.find("weight") != json_map.cend()) {
|
||||
Json weight = json_interface["weight"];
|
||||
CHECK(!IsA<Null>(weight));
|
||||
weights_.emplace_back(new dh::device_vector<float>);
|
||||
CopyMetaInfo(&weight, weights_.back().get(), copy_stream_);
|
||||
weight_interfaces_.emplace_back(weight);
|
||||
|
||||
Json::Dump(weight, &str);
|
||||
XGDMatrixSetInfoFromInterface(proxy_, "weight", str.c_str());
|
||||
}
|
||||
|
||||
if (json_map.find("baseMargin") != json_map.cend()) {
|
||||
Json basemargin = json_interface["baseMargin"];
|
||||
base_margins_.emplace_back(new dh::device_vector<float>);
|
||||
CopyMetaInfo(&basemargin, base_margins_.back().get(), copy_stream_);
|
||||
margin_interfaces_.emplace_back(basemargin);
|
||||
|
||||
Json::Dump(basemargin, &str);
|
||||
XGDMatrixSetInfoFromInterface(proxy_, "base_margin", str.c_str());
|
||||
}
|
||||
|
||||
if (json_map.find("qid") != json_map.cend()) {
|
||||
Json qid = json_interface["qid"];
|
||||
qids_.emplace_back(new dh::device_vector<int>);
|
||||
CopyMetaInfo(&qid, qids_.back().get(), copy_stream_);
|
||||
qid_interfaces_.emplace_back(qid);
|
||||
|
||||
Json::Dump(qid, &str);
|
||||
XGDMatrixSetInfoFromInterface(proxy_, "qid", str.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
void CloseJvmBatch() {
|
||||
if (last_batch_) {
|
||||
jclass batch_class = CheckJvmCall(jenv_->GetObjectClass(last_batch_), jenv_);
|
||||
jmethodID closeMethod = CheckJvmCall(jenv_->GetMethodID(batch_class, "close", "()V"), jenv_);
|
||||
jenv_->CallVoidMethod(last_batch_, closeMethod);
|
||||
last_batch_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
void Reset() {
|
||||
it_ = 0;
|
||||
this->CloseJvmBatch();
|
||||
}
|
||||
|
||||
int32_t PullIterFromJVM() {
|
||||
jclass iterClass = jenv_->FindClass("java/util/Iterator");
|
||||
this->CloseJvmBatch();
|
||||
|
||||
jmethodID has_next =
|
||||
CheckJvmCall(jenv_->GetMethodID(iterClass, "hasNext", "()Z"), jenv_);
|
||||
jmethodID next = CheckJvmCall(
|
||||
jenv_->GetMethodID(iterClass, "next", "()Ljava/lang/Object;"), jenv_);
|
||||
|
||||
if (jenv_->CallBooleanMethod(jiter_, has_next)) {
|
||||
// batch should be ColumnBatch from jvm
|
||||
jobject batch = CheckJvmCall(jenv_->CallObjectMethod(jiter_, next), jenv_);
|
||||
jclass batch_class = CheckJvmCall(jenv_->GetObjectClass(batch), jenv_);
|
||||
jmethodID toJson = CheckJvmCall(jenv_->GetMethodID(
|
||||
batch_class, "toJson", "()Ljava/lang/String;"), jenv_);
|
||||
|
||||
auto jinterface =
|
||||
static_cast<jstring>(jenv_->CallObjectMethod(batch, toJson));
|
||||
CheckJvmCall(jinterface, jenv_);
|
||||
char const *c_interface_str =
|
||||
CheckJvmCall(jenv_->GetStringUTFChars(jinterface, nullptr), jenv_);
|
||||
|
||||
StageData(c_interface_str);
|
||||
|
||||
jenv_->ReleaseStringUTFChars(jinterface, c_interface_str);
|
||||
|
||||
last_batch_ = batch;
|
||||
return 1;
|
||||
} else {
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
void StageData(std::string interface_str) {
|
||||
++n_batches_;
|
||||
// DataFrame
|
||||
using T = decltype(host_columns_)::value_type::element_type;
|
||||
host_columns_.emplace_back(std::unique_ptr<T>(new T));
|
||||
|
||||
// Stage the meta info.
|
||||
auto json_interface =
|
||||
Json::Load({interface_str.c_str(), interface_str.size()});
|
||||
CHECK(!IsA<Null>(json_interface));
|
||||
StageMetaInfo(json_interface);
|
||||
|
||||
Json features = json_interface["features"];
|
||||
auto json_columns = get<Array const>(features);
|
||||
std::vector<ArrayInterface<1>> interfaces;
|
||||
|
||||
// Stage the data
|
||||
for (auto &json_col : json_columns) {
|
||||
auto column = ArrayInterface<1>(get<Object const>(json_col));
|
||||
interfaces.emplace_back(column);
|
||||
}
|
||||
Json::Dump(features, &interface_str);
|
||||
CopyInterface(interfaces, json_columns, cudaMemcpyDeviceToHost,
|
||||
&host_columns_.back()->data, &host_columns_.back()->valid,
|
||||
&host_columns_.back()->interfaces, copy_stream_);
|
||||
|
||||
XGProxyDMatrixSetDataCudaColumnar(proxy_, interface_str.c_str());
|
||||
it_++;
|
||||
}
|
||||
|
||||
int NextFirstLoop() {
|
||||
try {
|
||||
dh::safe_cuda(cudaStreamSynchronize(copy_stream_));
|
||||
if (this->PullIterFromJVM()) {
|
||||
return 1;
|
||||
} else {
|
||||
initialized_ = true;
|
||||
return 0;
|
||||
}
|
||||
} catch (dmlc::Error const &e) {
|
||||
if (jni_status_ == JNI_EDETACHED) {
|
||||
GlobalJvm()->DetachCurrentThread();
|
||||
}
|
||||
LOG(FATAL) << e.what();
|
||||
}
|
||||
LOG(FATAL) << "Unreachable";
|
||||
return 1;
|
||||
}
|
||||
|
||||
int NextSecondLoop() {
|
||||
std::string str;
|
||||
// Meta
|
||||
auto const &label = this->label_interfaces_.at(it_);
|
||||
Json::Dump(label, &str);
|
||||
XGDMatrixSetInfoFromInterface(proxy_, "label", str.c_str());
|
||||
|
||||
if (n_batches_ == this->weight_interfaces_.size()) {
|
||||
auto const &weight = this->weight_interfaces_.at(it_);
|
||||
Json::Dump(weight, &str);
|
||||
XGDMatrixSetInfoFromInterface(proxy_, "weight", str.c_str());
|
||||
}
|
||||
|
||||
if (n_batches_ == this->margin_interfaces_.size()) {
|
||||
auto const &base_margin = this->margin_interfaces_.at(it_);
|
||||
Json::Dump(base_margin, &str);
|
||||
XGDMatrixSetInfoFromInterface(proxy_, "base_margin", str.c_str());
|
||||
}
|
||||
|
||||
if (n_batches_ == this->qid_interfaces_.size()) {
|
||||
auto const &qid = this->qid_interfaces_.at(it_);
|
||||
Json::Dump(qid, &str);
|
||||
XGDMatrixSetInfoFromInterface(proxy_, "qid", str.c_str());
|
||||
}
|
||||
|
||||
// Data
|
||||
auto const &json_interface = host_columns_.at(it_)->interfaces;
|
||||
|
||||
std::vector<ArrayInterface<1>> in;
|
||||
for (auto interface : json_interface) {
|
||||
auto column = ArrayInterface<1>(get<Object const>(interface));
|
||||
in.emplace_back(column);
|
||||
}
|
||||
std::vector<Json> out;
|
||||
CopyInterface(in, json_interface, cudaMemcpyHostToDevice, &staging_data_,
|
||||
&staging_mask_, &out, nullptr);
|
||||
|
||||
Json temp{Array(std::move(out))};
|
||||
std::string interface_str;
|
||||
Json::Dump(temp, &interface_str);
|
||||
XGProxyDMatrixSetDataCudaColumnar(proxy_, interface_str.c_str());
|
||||
it_++;
|
||||
return 1;
|
||||
}
|
||||
|
||||
int Next() {
|
||||
if (!initialized_) {
|
||||
return NextFirstLoop();
|
||||
} else {
|
||||
if (it_ == n_batches_) {
|
||||
return 0;
|
||||
}
|
||||
return NextSecondLoop();
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
namespace {
|
||||
void Reset(DataIterHandle self) {
|
||||
static_cast<xgboost::jni::DataIteratorProxy *>(self)->Reset();
|
||||
}
|
||||
|
||||
int Next(DataIterHandle self) {
|
||||
return static_cast<xgboost::jni::DataIteratorProxy *>(self)->Next();
|
||||
}
|
||||
} // anonymous namespace
|
||||
|
||||
XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass jcls,
|
||||
jobject jiter,
|
||||
jfloat jmissing,
|
||||
jint jmax_bin, jint jnthread,
|
||||
jlongArray jout) {
|
||||
xgboost::jni::DataIteratorProxy proxy(jiter);
|
||||
DMatrixHandle result;
|
||||
auto ret = XGDeviceQuantileDMatrixCreateFromCallback(
|
||||
&proxy, proxy.GetDMatrixHandle(), Reset, Next, jmissing, jnthread,
|
||||
jmax_bin, &result);
|
||||
setHandle(jenv, jout, result);
|
||||
return ret;
|
||||
}
|
||||
|
||||
XGB_DLL int XGQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass jcls,
|
||||
jobject jdata_iter, jobject jref_iter,
|
||||
char const *config, jlongArray jout) {
|
||||
xgboost::jni::DataIteratorProxy proxy(jdata_iter);
|
||||
DMatrixHandle result;
|
||||
|
||||
std::unique_ptr<xgboost::jni::DataIteratorProxy> ref_proxy{nullptr};
|
||||
if (jref_iter) {
|
||||
ref_proxy = std::make_unique<xgboost::jni::DataIteratorProxy>(jref_iter);
|
||||
}
|
||||
auto ret = XGQuantileDMatrixCreateFromCallback(
|
||||
&proxy, proxy.GetDMatrixHandle(), ref_proxy.get(), Reset, Next, config, &result);
|
||||
setHandle(jenv, jout, result);
|
||||
return ret;
|
||||
}
|
||||
} // namespace jni
|
||||
} // namespace xgboost
|
||||
Reference in New Issue
Block a user