Implement iterative DMatrix. (#5837)
This commit is contained in:
@@ -23,6 +23,7 @@
|
||||
#include "../common/io.h"
|
||||
#include "../data/adapter.h"
|
||||
#include "../data/simple_dmatrix.h"
|
||||
#include "../data/proxy_dmatrix.h"
|
||||
|
||||
using namespace xgboost; // NOLINT(*);
|
||||
|
||||
@@ -101,6 +102,50 @@ XGB_DLL int XGDMatrixCreateFromArrayInterface(char const* c_json_strs,
|
||||
|
||||
#endif
|
||||
|
||||
// Create from data iterator
|
||||
XGB_DLL int XGProxyDMatrixCreate(DMatrixHandle* out) {
|
||||
API_BEGIN();
|
||||
*out = new std::shared_ptr<xgboost::DMatrix>(new xgboost::data::DMatrixProxy);;
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int
|
||||
XGDeviceQuantileDMatrixSetDataCudaArrayInterface(DMatrixHandle handle,
|
||||
char const *c_interface_str) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
auto p_m = static_cast<std::shared_ptr<xgboost::DMatrix> *>(handle);
|
||||
CHECK(p_m);
|
||||
auto m = static_cast<xgboost::data::DMatrixProxy*>(p_m->get());
|
||||
CHECK(m) << "Current DMatrix type does not support set data.";
|
||||
m->SetData(c_interface_str);
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int
|
||||
XGDeviceQuantileDMatrixSetDataCudaColumnar(DMatrixHandle handle,
|
||||
char const *c_interface_str) {
|
||||
API_BEGIN();
|
||||
CHECK_HANDLE();
|
||||
auto p_m = static_cast<std::shared_ptr<xgboost::DMatrix> *>(handle);
|
||||
CHECK(p_m);
|
||||
auto m = static_cast<xgboost::data::DMatrixProxy*>(p_m->get());
|
||||
CHECK(m) << "Current DMatrix type does not support set data.";
|
||||
m->SetData(c_interface_str);
|
||||
API_END();
|
||||
}
|
||||
|
||||
XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallback(
|
||||
DataIterHandle iter, DMatrixHandle proxy, DataIterResetCallback *reset,
|
||||
XGDMatrixCallbackNext *next, float missing, int nthread,
|
||||
int max_bin, DMatrixHandle *out) {
|
||||
API_BEGIN();
|
||||
*out = new std::shared_ptr<xgboost::DMatrix>{
|
||||
xgboost::DMatrix::Create(iter, proxy, reset, next, missing, nthread, max_bin)};
|
||||
API_END();
|
||||
}
|
||||
// End Create from data iterator
|
||||
|
||||
XGB_DLL int XGDMatrixCreateFromCSREx(const size_t* indptr,
|
||||
const unsigned* indices,
|
||||
const bst_float* data,
|
||||
|
||||
Reference in New Issue
Block a user