Implement iterative DMatrix. (#5837)

This commit is contained in:
Jiaming Yuan
2020-07-03 11:44:52 +08:00
committed by GitHub
parent 4d277d750d
commit 1a0801238e
15 changed files with 855 additions and 84 deletions

View File

@@ -19,6 +19,7 @@
#include "../common/version.h"
#include "../common/group_data.h"
#include "../data/adapter.h"
#include "../data/iterative_device_dmatrix.h"
#if DMLC_ENABLE_STD_THREAD
#include "./sparse_page_source.h"
@@ -569,6 +570,26 @@ DMatrix* DMatrix::Load(const std::string& uri,
}
return dmat;
}
template <typename DataIterHandle, typename DMatrixHandle,
typename DataIterResetCallback, typename XGDMatrixCallbackNext>
DMatrix *DMatrix::Create(DataIterHandle iter, DMatrixHandle proxy,
DataIterResetCallback *reset,
XGDMatrixCallbackNext *next, float missing,
int nthread,
int max_bin) {
#if defined(XGBOOST_USE_CUDA)
return new data::IterativeDeviceDMatrix(iter, proxy, reset, next, missing, nthread, max_bin);
#else
common::AssertGPUSupport();
return nullptr;
#endif
}
template DMatrix *DMatrix::Create<DataIterHandle, DMatrixHandle,
DataIterResetCallback, XGDMatrixCallbackNext>(
DataIterHandle iter, DMatrixHandle proxy, DataIterResetCallback *reset,
XGDMatrixCallbackNext *next, float missing, int nthread,
int max_bin);
template <typename AdapterT>
DMatrix* DMatrix::Create(AdapterT* adapter, float missing, int nthread,