External memory support for hist (#7531)

* Generate column matrix from gHistIndex.
* Avoid synchronization with the sparse page once the cache is written.
* Cleanups: Remove member variables/functions, change the update routine to look like approx and gpu_hist.
* Remove pruner.
This commit is contained in:
Jiaming Yuan
2022-03-22 00:13:20 +08:00
committed by GitHub
parent cd55823112
commit 4d81c741e9
25 changed files with 563 additions and 686 deletions

View File

@@ -3,6 +3,7 @@
*/
#include <gtest/gtest.h>
#include "../../../src/common/column_matrix.h"
#include "../../../src/data/gradient_index.h"
#include "../../../src/data/sparse_page_source.h"
#include "../helpers.h"
@@ -15,33 +16,31 @@ TEST(GHistIndexPageRawFormat, IO) {
auto m = RandomDataGenerator{100, 14, 0.5}.GenerateDMatrix();
dmlc::TemporaryDirectory tmpdir;
std::string path = tmpdir.path + "/ghistindex.page";
auto batch = BatchParam{256, 0.5};
{
std::unique_ptr<dmlc::Stream> fo{dmlc::Stream::Create(path.c_str(), "w")};
for (auto const &index :
m->GetBatches<GHistIndexMatrix>({GenericParameter::kCpuId, 256})) {
for (auto const &index : m->GetBatches<GHistIndexMatrix>(batch)) {
format->Write(index, fo.get());
}
}
GHistIndexMatrix page;
std::unique_ptr<dmlc::SeekStream> fi{
dmlc::SeekStream::CreateForRead(path.c_str())};
std::unique_ptr<dmlc::SeekStream> fi{dmlc::SeekStream::CreateForRead(path.c_str())};
format->Read(&page, fi.get());
for (auto const &gidx :
m->GetBatches<GHistIndexMatrix>({GenericParameter::kCpuId, 256})) {
for (auto const &gidx : m->GetBatches<GHistIndexMatrix>(batch)) {
auto const &loaded = gidx;
ASSERT_EQ(loaded.cut.Ptrs(), page.cut.Ptrs());
ASSERT_EQ(loaded.cut.MinValues(), page.cut.MinValues());
ASSERT_EQ(loaded.cut.Values(), page.cut.Values());
ASSERT_EQ(loaded.base_rowid, page.base_rowid);
ASSERT_EQ(loaded.IsDense(), page.IsDense());
ASSERT_TRUE(std::equal(loaded.index.begin(), loaded.index.end(),
page.index.begin()));
ASSERT_TRUE(std::equal(loaded.index.Offset(),
loaded.index.Offset() + loaded.index.OffsetSize(),
ASSERT_TRUE(std::equal(loaded.index.begin(), loaded.index.end(), page.index.begin()));
ASSERT_TRUE(std::equal(loaded.index.Offset(), loaded.index.Offset() + loaded.index.OffsetSize(),
page.index.Offset()));
// The page read back from the stream must carry the same column-matrix element size as the
// batch produced by the DMatrix; comparing `loaded` against itself would always pass.
ASSERT_EQ(loaded.Transpose().GetTypeSize(), page.Transpose().GetTypeSize());
}
}
} // namespace data