#include #include "../../../../src/common/device_helpers.cuh" #include "../../../../src/common/cuda_pinned_allocator.h" #include "../../../../src/data/array_interface.h" #include "jvm_utils.h" #include namespace xgboost { namespace jni { template T const *RawPtr(std::vector const &data) { return data.data(); } template T *RawPtr(std::vector &data) { return data.data(); } template T const *RawPtr(dh::device_vector const &data) { return data.data().get(); } template T *RawPtr(dh::device_vector &data) { return data.data().get(); } template T CheckJvmCall(T const &v, JNIEnv *jenv) { if (!v) { CHECK(jenv->ExceptionOccurred()); jenv->ExceptionDescribe(); } return v; } template void CopyColumnMask(xgboost::ArrayInterface<1> const &interface, std::vector const &columns, cudaMemcpyKind kind, size_t c, VCont *p_mask, Json *p_out, cudaStream_t stream) { auto &mask = *p_mask; auto &out = *p_out; auto size = sizeof(typename VCont::value_type) * interface.n; mask.resize(size); CHECK(RawPtr(mask)); CHECK(size); CHECK(interface.valid.Data()); dh::safe_cuda( cudaMemcpyAsync(RawPtr(mask), interface.valid.Data(), size, kind, stream)); auto const &mask_column = columns[c]["mask"]; out["mask"] = Object(); std::vector mask_data{ Json{reinterpret_cast(RawPtr(mask))}, Json{get(mask_column["data"][1])}}; out["mask"]["data"] = Array(std::move(mask_data)); if (get(mask_column["shape"]).size() == 2) { std::vector mask_shape{ Json{get(mask_column["shape"][0])}, Json{get(mask_column["shape"][1])}}; out["mask"]["shape"] = Array(std::move(mask_shape)); } else if (get(mask_column["shape"]).size() == 1) { std::vector mask_shape{ Json{get(mask_column["shape"][0])}}; out["mask"]["shape"] = Array(std::move(mask_shape)); } else { LOG(FATAL) << "Invalid shape of mask"; } out["mask"]["typestr"] = String(" void CopyInterface(std::vector> &interface_arr, std::vector const &columns, cudaMemcpyKind kind, std::vector *p_data, std::vector *p_mask, std::vector *p_out, cudaStream_t stream) { p_data->resize(interface_arr.size()); p_mask->resize(interface_arr.size()); p_out->resize(interface_arr.size()); for (size_t c = 0; c < interface_arr.size(); ++c) { auto &interface = interface_arr.at(c); size_t element_size = interface.ElementSize(); size_t size = element_size * interface.n; auto &data = (*p_data)[c]; auto &mask = (*p_mask)[c]; data.resize(size); dh::safe_cuda(cudaMemcpyAsync(RawPtr(data), interface.data, size, kind, stream)); auto &out = (*p_out)[c]; out = Object(); std::vector j_data{ Json{Integer(reinterpret_cast(RawPtr(data)))}, Json{Boolean{false}}}; out["data"] = Array(std::move(j_data)); out["shape"] = Array(std::vector{Json(Integer(interface.Shape(0)))}); if (interface.valid.Data()) { CopyColumnMask(interface, columns, kind, c, &mask, &out, stream); } out["typestr"] = String(" void CopyMetaInfo(Json *p_interface, dh::device_vector *out, cudaStream_t stream) { auto &j_interface = *p_interface; CHECK_EQ(get(j_interface).size(), 1); auto object = get(get(j_interface)[0]); ArrayInterface<1> interface(object); out->resize(interface.Shape(0)); size_t element_size = interface.ElementSize(); size_t size = element_size * interface.n; dh::safe_cuda(cudaMemcpyAsync(RawPtr(*out), interface.data, size, cudaMemcpyDeviceToDevice, stream)); j_interface[0]["data"][0] = reinterpret_cast(RawPtr(*out)); } template struct DataFrame { std::vector data; std::vector valid; std::vector interfaces; }; class DataIteratorProxy { DMatrixHandle proxy_; JNIEnv *jenv_; int jni_status_; jobject jiter_; bool cache_on_host_{true}; // TODO(Bobby): Make this optional. template using Alloc = xgboost::common::cuda_impl::pinned_allocator; template using HostVector = std::vector>; // This vector is created for staging device data on host to save GPU memory. // When space is not of concern, we can stage them on device memory directly. std::vector< std::unique_ptr, HostVector>>> host_columns_; // TODO(Bobby): Use this instead of `host_columns_` if staging is not // required. std::vector, dh::device_vector>>> device_columns_; // Staging area for metainfo. // TODO(Bobby): label_upper_bound, label_lower_bound, group. std::vector>> labels_; std::vector>> weights_; std::vector>> base_margins_; std::vector>> qids_; std::vector label_interfaces_; std::vector weight_interfaces_; std::vector margin_interfaces_; std::vector qid_interfaces_; size_t it_{0}; size_t n_batches_{0}; bool initialized_{false}; jobject last_batch_ {nullptr}; // Temp buffer on device, each `dh::device_vector` represents a column // from cudf. std::vector> staging_data_; std::vector> staging_mask_; cudaStream_t copy_stream_; public: explicit DataIteratorProxy(jobject jiter, bool cache_on_host = true) : jiter_{jiter}, cache_on_host_{cache_on_host} { XGProxyDMatrixCreate(&proxy_); jni_status_ = GlobalJvm()->GetEnv(reinterpret_cast(&jenv_), JNI_VERSION_1_6); this->Reset(); dh::safe_cuda(cudaStreamCreateWithFlags(©_stream_, cudaStreamNonBlocking)); } ~DataIteratorProxy() { XGDMatrixFree(proxy_); dh::safe_cuda(cudaStreamDestroy(copy_stream_)); } DMatrixHandle GetDMatrixHandle() const { return proxy_; } // Helper function for staging meta info. void StageMetaInfo(Json json_interface) { CHECK(!IsA(json_interface)); auto json_map = get(json_interface); if (json_map.find("label") == json_map.cend()) { LOG(FATAL) << "Must have a label field."; } Json label = json_interface["label"]; CHECK(!IsA(label)); labels_.emplace_back(new dh::device_vector); CopyMetaInfo(&label, labels_.back().get(), copy_stream_); label_interfaces_.emplace_back(label); std::string str; Json::Dump(label, &str); XGDMatrixSetInfoFromInterface(proxy_, "label", str.c_str()); if (json_map.find("weight") != json_map.cend()) { Json weight = json_interface["weight"]; CHECK(!IsA(weight)); weights_.emplace_back(new dh::device_vector); CopyMetaInfo(&weight, weights_.back().get(), copy_stream_); weight_interfaces_.emplace_back(weight); Json::Dump(weight, &str); XGDMatrixSetInfoFromInterface(proxy_, "weight", str.c_str()); } if (json_map.find("baseMargin") != json_map.cend()) { Json basemargin = json_interface["baseMargin"]; base_margins_.emplace_back(new dh::device_vector); CopyMetaInfo(&basemargin, base_margins_.back().get(), copy_stream_); margin_interfaces_.emplace_back(basemargin); Json::Dump(basemargin, &str); XGDMatrixSetInfoFromInterface(proxy_, "base_margin", str.c_str()); } if (json_map.find("qid") != json_map.cend()) { Json qid = json_interface["qid"]; qids_.emplace_back(new dh::device_vector); CopyMetaInfo(&qid, qids_.back().get(), copy_stream_); qid_interfaces_.emplace_back(qid); Json::Dump(qid, &str); XGDMatrixSetInfoFromInterface(proxy_, "qid", str.c_str()); } } void CloseJvmBatch() { if (last_batch_) { jclass batch_class = CheckJvmCall(jenv_->GetObjectClass(last_batch_), jenv_); jmethodID closeMethod = CheckJvmCall(jenv_->GetMethodID(batch_class, "close", "()V"), jenv_); jenv_->CallVoidMethod(last_batch_, closeMethod); last_batch_ = nullptr; } } void Reset() { it_ = 0; this->CloseJvmBatch(); } int32_t PullIterFromJVM() { jclass iterClass = jenv_->FindClass("java/util/Iterator"); this->CloseJvmBatch(); jmethodID has_next = CheckJvmCall(jenv_->GetMethodID(iterClass, "hasNext", "()Z"), jenv_); jmethodID next = CheckJvmCall( jenv_->GetMethodID(iterClass, "next", "()Ljava/lang/Object;"), jenv_); if (jenv_->CallBooleanMethod(jiter_, has_next)) { // batch should be ColumnBatch from jvm jobject batch = CheckJvmCall(jenv_->CallObjectMethod(jiter_, next), jenv_); jclass batch_class = CheckJvmCall(jenv_->GetObjectClass(batch), jenv_); jmethodID toJson = CheckJvmCall(jenv_->GetMethodID( batch_class, "toJson", "()Ljava/lang/String;"), jenv_); auto jinterface = static_cast(jenv_->CallObjectMethod(batch, toJson)); CheckJvmCall(jinterface, jenv_); char const *c_interface_str = CheckJvmCall(jenv_->GetStringUTFChars(jinterface, nullptr), jenv_); StageData(c_interface_str); jenv_->ReleaseStringUTFChars(jinterface, c_interface_str); last_batch_ = batch; return 1; } else { return 0; } } void StageData(std::string interface_str) { ++n_batches_; // DataFrame using T = decltype(host_columns_)::value_type::element_type; host_columns_.emplace_back(std::unique_ptr(new T)); // Stage the meta info. auto json_interface = Json::Load({interface_str.c_str(), interface_str.size()}); CHECK(!IsA(json_interface)); StageMetaInfo(json_interface); Json features = json_interface["features"]; auto json_columns = get(features); std::vector> interfaces; // Stage the data for (auto &json_col : json_columns) { auto column = ArrayInterface<1>(get(json_col)); interfaces.emplace_back(column); } Json::Dump(features, &interface_str); CopyInterface(interfaces, json_columns, cudaMemcpyDeviceToHost, &host_columns_.back()->data, &host_columns_.back()->valid, &host_columns_.back()->interfaces, copy_stream_); XGProxyDMatrixSetDataCudaColumnar(proxy_, interface_str.c_str()); it_++; } int NextFirstLoop() { try { dh::safe_cuda(cudaStreamSynchronize(copy_stream_)); if (this->PullIterFromJVM()) { return 1; } else { initialized_ = true; return 0; } } catch (dmlc::Error const &e) { if (jni_status_ == JNI_EDETACHED) { GlobalJvm()->DetachCurrentThread(); } LOG(FATAL) << e.what(); } LOG(FATAL) << "Unreachable"; return 1; } int NextSecondLoop() { std::string str; // Meta auto const &label = this->label_interfaces_.at(it_); Json::Dump(label, &str); XGDMatrixSetInfoFromInterface(proxy_, "label", str.c_str()); if (n_batches_ == this->weight_interfaces_.size()) { auto const &weight = this->weight_interfaces_.at(it_); Json::Dump(weight, &str); XGDMatrixSetInfoFromInterface(proxy_, "weight", str.c_str()); } if (n_batches_ == this->margin_interfaces_.size()) { auto const &base_margin = this->margin_interfaces_.at(it_); Json::Dump(base_margin, &str); XGDMatrixSetInfoFromInterface(proxy_, "base_margin", str.c_str()); } if (n_batches_ == this->qid_interfaces_.size()) { auto const &qid = this->qid_interfaces_.at(it_); Json::Dump(qid, &str); XGDMatrixSetInfoFromInterface(proxy_, "qid", str.c_str()); } // Data auto const &json_interface = host_columns_.at(it_)->interfaces; std::vector> in; for (auto interface : json_interface) { auto column = ArrayInterface<1>(get(interface)); in.emplace_back(column); } std::vector out; CopyInterface(in, json_interface, cudaMemcpyHostToDevice, &staging_data_, &staging_mask_, &out, nullptr); Json temp{Array(std::move(out))}; std::string interface_str; Json::Dump(temp, &interface_str); XGProxyDMatrixSetDataCudaColumnar(proxy_, interface_str.c_str()); it_++; return 1; } int Next() { if (!initialized_) { return NextFirstLoop(); } else { if (it_ == n_batches_) { return 0; } return NextSecondLoop(); } }; }; namespace { void Reset(DataIterHandle self) { static_cast(self)->Reset(); } int Next(DataIterHandle self) { return static_cast(self)->Next(); } } // anonymous namespace XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass jcls, jobject jiter, jfloat jmissing, jint jmax_bin, jint jnthread, jlongArray jout) { xgboost::jni::DataIteratorProxy proxy(jiter); DMatrixHandle result; auto ret = XGDeviceQuantileDMatrixCreateFromCallback( &proxy, proxy.GetDMatrixHandle(), Reset, Next, jmissing, jnthread, jmax_bin, &result); setHandle(jenv, jout, result); return ret; } XGB_DLL int XGQuantileDMatrixCreateFromCallbackImpl(JNIEnv *jenv, jclass jcls, jobject jdata_iter, jobject jref_iter, char const *config, jlongArray jout) { xgboost::jni::DataIteratorProxy proxy(jdata_iter); DMatrixHandle result; std::unique_ptr ref_proxy{nullptr}; if (jref_iter) { ref_proxy = std::make_unique(jref_iter); } auto ret = XGQuantileDMatrixCreateFromCallback( &proxy, proxy.GetDMatrixHandle(), ref_proxy.get(), Reset, Next, config, &result); setHandle(jenv, jout, result); return ret; } } // namespace jni } // namespace xgboost