/*! * Copyright 2020-2022, XGBoost contributors */ #ifndef XGBOOST_DATA_PROXY_DMATRIX_H_ #define XGBOOST_DATA_PROXY_DMATRIX_H_ #include #include #include #include #include "xgboost/data.h" #include "xgboost/generic_parameters.h" #include "xgboost/c_api.h" #include "adapter.h" namespace xgboost { namespace data { /* * \brief A proxy to external iterator. */ template class DataIterProxy { DataIterHandle iter_; ResetFn* reset_; NextFn* next_; public: DataIterProxy(DataIterHandle iter, ResetFn* reset, NextFn* next) : iter_{iter}, reset_{reset}, next_{next} {} bool Next() { return next_(iter_); } void Reset() { reset_(iter_); } }; /* * \brief A proxy of DMatrix used by external iterator. */ class DMatrixProxy : public DMatrix { MetaInfo info_; dmlc::any batch_; Context ctx_; #if defined(XGBOOST_USE_CUDA) void FromCudaColumnar(std::string interface_str); void FromCudaArray(std::string interface_str); #endif // defined(XGBOOST_USE_CUDA) public: int DeviceIdx() const { return ctx_.gpu_id; } void SetData(char const* c_interface) { common::AssertGPUSupport(); #if defined(XGBOOST_USE_CUDA) std::string interface_str = c_interface; Json json_array_interface = Json::Load({interface_str.c_str(), interface_str.size()}); if (IsA(json_array_interface)) { this->FromCudaColumnar(interface_str); } else { this->FromCudaArray(interface_str); } if (this->info_.num_row_ == 0) { this->ctx_.gpu_id = Context::kCpuId; } #endif // defined(XGBOOST_USE_CUDA) } void SetArrayData(char const* c_interface); void SetCSRData(char const *c_indptr, char const *c_indices, char const *c_values, bst_feature_t n_features, bool on_host); MetaInfo& Info() override { return info_; } MetaInfo const& Info() const override { return info_; } Context const* Ctx() const override { return &ctx_; } bool SingleColBlock() const override { return true; } bool EllpackExists() const override { return true; } bool SparsePageExists() const override { return false; } DMatrix *Slice(common::Span ridxs) override { LOG(FATAL) << "Slicing DMatrix is not supported for Proxy DMatrix."; return nullptr; } BatchSet GetRowBatches() override { LOG(FATAL) << "Not implemented."; return BatchSet(BatchIterator(nullptr)); } BatchSet GetColumnBatches() override { LOG(FATAL) << "Not implemented."; return BatchSet(BatchIterator(nullptr)); } BatchSet GetSortedColumnBatches() override { LOG(FATAL) << "Not implemented."; return BatchSet(BatchIterator(nullptr)); } BatchSet GetEllpackBatches(const BatchParam& param) override { LOG(FATAL) << "Not implemented."; return BatchSet(BatchIterator(nullptr)); } BatchSet GetGradientIndex(const BatchParam&) override { LOG(FATAL) << "Not implemented."; return BatchSet(BatchIterator(nullptr)); } dmlc::any Adapter() const { return batch_; } }; inline DMatrixProxy *MakeProxy(DMatrixHandle proxy) { auto proxy_handle = static_cast *>(proxy); CHECK(proxy_handle) << "Invalid proxy handle."; DMatrixProxy *typed = static_cast(proxy_handle->get()); return typed; } template decltype(auto) HostAdapterDispatch(DMatrixProxy const* proxy, Fn fn, bool* type_error = nullptr) { if (proxy->Adapter().type() == typeid(std::shared_ptr)) { auto value = dmlc::get>(proxy->Adapter())->Value(); if (type_error) { *type_error = false; } return fn(value); } else if (proxy->Adapter().type() == typeid(std::shared_ptr)) { auto value = dmlc::get>( proxy->Adapter())->Value(); if (type_error) { *type_error = false; } return fn(value); } else { if (type_error) { *type_error = true; } else { LOG(FATAL) << "Unknown type: " << proxy->Adapter().type().name(); } return std::result_of_t>()->Value()))>(); } } } // namespace data } // namespace xgboost #endif // XGBOOST_DATA_PROXY_DMATRIX_H_