add a test for cpu predictor using external memory (#4308)

* add a test for cpu predictor using external memory

* allow different page size for testing
This commit is contained in:
Rong Ou
2019-04-09 18:25:10 -07:00
committed by Rory Mitchell
parent b72eab3e07
commit 81c1cd40ca
7 changed files with 95 additions and 20 deletions

View File

@@ -150,7 +150,8 @@ void MetaInfo::SetInfo(const char* key, const void* dptr, DataType dtype, size_t
DMatrix* DMatrix::Load(const std::string& uri,
bool silent,
bool load_row_split,
const std::string& file_format) {
const std::string& file_format,
const size_t page_size) {
std::string fname, cache_file;
size_t dlm_pos = uri.find('#');
if (dlm_pos != std::string::npos) {
@@ -217,7 +218,7 @@ DMatrix* DMatrix::Load(const std::string& uri,
std::unique_ptr<dmlc::Parser<uint32_t> > parser(
dmlc::Parser<uint32_t>::Create(fname.c_str(), partid, npart, file_format.c_str()));
DMatrix* dmat = DMatrix::Create(parser.get(), cache_file);
DMatrix* dmat = DMatrix::Create(parser.get(), cache_file, page_size);
if (!silent) {
LOG(CONSOLE) << dmat->Info().num_row_ << 'x' << dmat->Info().num_col_ << " matrix with "
<< dmat->Info().num_nonzero_ << " entries loaded from " << uri;
@@ -248,7 +249,8 @@ DMatrix* DMatrix::Load(const std::string& uri,
}
DMatrix* DMatrix::Create(dmlc::Parser<uint32_t>* parser,
const std::string& cache_prefix) {
const std::string& cache_prefix,
const size_t page_size) {
if (cache_prefix.length() == 0) {
std::unique_ptr<data::SimpleCSRSource> source(new data::SimpleCSRSource());
source->CopyFrom(parser);
@@ -256,7 +258,7 @@ DMatrix* DMatrix::Create(dmlc::Parser<uint32_t>* parser,
} else {
#if DMLC_ENABLE_STD_THREAD
if (!data::SparsePageSource::CacheExist(cache_prefix, ".row.page")) {
data::SparsePageSource::CreateRowPage(parser, cache_prefix);
data::SparsePageSource::CreateRowPage(parser, cache_prefix, page_size);
}
std::unique_ptr<data::SparsePageSource> source(
new data::SparsePageSource(cache_prefix, ".row.page"));

View File

@@ -40,9 +40,6 @@ class SparsePageDMatrix : public DMatrix {
bool SingleColBlock() const override;
private:
/*! \brief page size 256 MB */
static const size_t kPageSize = 256UL << 20UL;
// source data pointers.
std::unique_ptr<DataSource> row_source_;
std::unique_ptr<SparsePageSource> column_source_;

View File

@@ -126,7 +126,8 @@ bool SparsePageSource::CacheExist(const std::string& cache_info,
}
void SparsePageSource::CreateRowPage(dmlc::Parser<uint32_t>* src,
const std::string& cache_info) {
const std::string& cache_info,
const size_t page_size) {
const std::string page_type = ".row.page";
std::vector<std::string> cache_shards = GetCacheShards(cache_info);
CHECK_NE(cache_shards.size(), 0U);
@@ -183,7 +184,7 @@ void SparsePageSource::CreateRowPage(dmlc::Parser<uint32_t>* src,
static_cast<uint64_t>(index + 1));
}
page->Push(batch);
if (page->MemCostBytes() >= kPageSize) {
if (page->MemCostBytes() >= page_size) {
bytes_write += page->MemCostBytes();
writer.PushWrite(std::move(page));
writer.Alloc(&page);
@@ -222,7 +223,8 @@ void SparsePageSource::CreateRowPage(dmlc::Parser<uint32_t>* src,
void SparsePageSource::CreatePageFromDMatrix(DMatrix* src,
const std::string& cache_info,
const std::string& page_type) {
const std::string& page_type,
const size_t page_size) {
std::vector<std::string> cache_shards = GetCacheShards(cache_info);
CHECK_NE(cache_shards.size(), 0U);
// read in the info files.
@@ -254,7 +256,7 @@ void SparsePageSource::CreatePageFromDMatrix(DMatrix* src,
LOG(FATAL) << "Unknown page type: " << page_type;
}
if (page->MemCostBytes() >= kPageSize) {
if (page->MemCostBytes() >= page_size) {
bytes_write += page->MemCostBytes();
writer.PushWrite(std::move(page));
writer.Alloc(&page);

View File

@@ -48,9 +48,11 @@ class SparsePageSource : public DataSource {
* \brief Create source by taking data from parser.
* \param src source parser.
* \param cache_info The cache_info of cache file location.
* \param page_size Page size for external memory.
*/
static void CreateRowPage(dmlc::Parser<uint32_t>* src,
const std::string& cache_info);
const std::string& cache_info,
const size_t page_size = DMatrix::kPageSize);
/*!
* \brief Create source cache by copy content from DMatrix.
* \param cache_info The cache_info of cache file location.
@@ -73,14 +75,13 @@ class SparsePageSource : public DataSource {
*/
static bool CacheExist(const std::string& cache_info,
const std::string& page_type);
/*! \brief page size 32 MB */
static const size_t kPageSize = 32UL << 20UL;
/*! \brief magic number used to identify Page */
static const int kMagic = 0xffffab02;
private:
static void CreatePageFromDMatrix(DMatrix* src, const std::string& cache_info,
const std::string& page_type);
const std::string& page_type,
const size_t page_size = DMatrix::kPageSize);
/*! \brief number of rows */
size_t base_rowid_;
/*! \brief page currently on hold. */