Use Booster context in DMatrix. (#8896)
- Pass the context from the booster to the DMatrix.
- Use the context instead of an integer for `n_threads`.
- Check that the `max_bin` configuration is consistent.
- Test all combinations of initialization options.
This commit is contained in:
@@ -23,7 +23,7 @@ std::string UriSVM(std::string name, std::string cache) {
|
||||
} // namespace
|
||||
|
||||
template <typename Page>
|
||||
void TestSparseDMatrixLoadFile() {
|
||||
void TestSparseDMatrixLoadFile(Context const* ctx) {
|
||||
dmlc::TemporaryDirectory tmpdir;
|
||||
auto opath = tmpdir.path + "/1-based.svm";
|
||||
CreateBigTestData(opath, 3 * 64, false);
|
||||
@@ -48,7 +48,7 @@ void TestSparseDMatrixLoadFile() {
|
||||
data::SimpleDMatrix simple{&adapter, std::numeric_limits<float>::quiet_NaN(),
|
||||
1};
|
||||
Page out;
|
||||
for (auto const& page : m.GetBatches<Page>()) {
|
||||
for (auto const &page : m.GetBatches<Page>(ctx)) {
|
||||
if (std::is_same<Page, SparsePage>::value) {
|
||||
out.Push(page);
|
||||
} else {
|
||||
@@ -58,7 +58,7 @@ void TestSparseDMatrixLoadFile() {
|
||||
ASSERT_EQ(m.Info().num_col_, simple.Info().num_col_);
|
||||
ASSERT_EQ(m.Info().num_row_, simple.Info().num_row_);
|
||||
|
||||
for (auto const& page : simple.GetBatches<Page>()) {
|
||||
for (auto const& page : simple.GetBatches<Page>(ctx)) {
|
||||
ASSERT_EQ(page.offset.HostVector(), out.offset.HostVector());
|
||||
for (size_t i = 0; i < page.data.Size(); ++i) {
|
||||
ASSERT_EQ(page.data.HostVector()[i].fvalue, out.data.HostVector()[i].fvalue);
|
||||
@@ -67,16 +67,18 @@ void TestSparseDMatrixLoadFile() {
|
||||
}
|
||||
|
||||
TEST(SparsePageDMatrix, LoadFile) {
|
||||
TestSparseDMatrixLoadFile<SparsePage>();
|
||||
TestSparseDMatrixLoadFile<CSCPage>();
|
||||
TestSparseDMatrixLoadFile<SortedCSCPage>();
|
||||
auto ctx = CreateEmptyGenericParam(Context::kCpuId);
|
||||
TestSparseDMatrixLoadFile<SparsePage>(&ctx);
|
||||
TestSparseDMatrixLoadFile<CSCPage>(&ctx);
|
||||
TestSparseDMatrixLoadFile<SortedCSCPage>(&ctx);
|
||||
}
|
||||
|
||||
// allow caller to retain pages so they can process multiple pages at the same time.
|
||||
template <typename Page>
|
||||
void TestRetainPage() {
|
||||
auto m = CreateSparsePageDMatrix(10000);
|
||||
auto batches = m->GetBatches<Page>();
|
||||
auto ctx = CreateEmptyGenericParam(Context::kCpuId);
|
||||
auto batches = m->GetBatches<Page>(&ctx);
|
||||
auto begin = batches.begin();
|
||||
auto end = batches.end();
|
||||
|
||||
@@ -100,7 +102,7 @@ void TestRetainPage() {
|
||||
}
|
||||
|
||||
// make sure it's const and the caller can not modify the content of page.
|
||||
for (auto& page : m->GetBatches<Page>()) {
|
||||
for (auto &page : m->GetBatches<Page>({&ctx})) {
|
||||
static_assert(std::is_const<std::remove_reference_t<decltype(page)>>::value);
|
||||
}
|
||||
}
|
||||
@@ -143,10 +145,11 @@ TEST(SparsePageDMatrix, ColAccess) {
|
||||
const std::string tmp_file = tempdir.path + "/simple.libsvm";
|
||||
CreateSimpleTestData(tmp_file);
|
||||
xgboost::DMatrix *dmat = xgboost::DMatrix::Load(UriSVM(tmp_file, tmp_file));
|
||||
auto ctx = CreateEmptyGenericParam(Context::kCpuId);
|
||||
|
||||
// Loop over the batches and assert the data is as expected
|
||||
size_t iter = 0;
|
||||
for (auto const &col_batch : dmat->GetBatches<xgboost::SortedCSCPage>()) {
|
||||
for (auto const &col_batch : dmat->GetBatches<xgboost::SortedCSCPage>(&ctx)) {
|
||||
auto col_page = col_batch.GetView();
|
||||
ASSERT_EQ(col_page.Size(), dmat->Info().num_col_);
|
||||
if (iter == 1) {
|
||||
@@ -164,7 +167,7 @@ TEST(SparsePageDMatrix, ColAccess) {
|
||||
|
||||
// Loop over the batches and assert the data is as expected
|
||||
iter = 0;
|
||||
for (auto const &col_batch : dmat->GetBatches<xgboost::CSCPage>()) {
|
||||
for (auto const &col_batch : dmat->GetBatches<xgboost::CSCPage>(&ctx)) {
|
||||
auto col_page = col_batch.GetView();
|
||||
EXPECT_EQ(col_page.Size(), dmat->Info().num_col_);
|
||||
if (iter == 0) {
|
||||
@@ -182,9 +185,9 @@ TEST(SparsePageDMatrix, ColAccess) {
|
||||
TEST(SparsePageDMatrix, ThreadSafetyException) {
|
||||
size_t constexpr kEntriesPerCol = 3;
|
||||
size_t constexpr kEntries = 64 * kEntriesPerCol * 2;
|
||||
Context ctx;
|
||||
|
||||
std::unique_ptr<xgboost::DMatrix> dmat =
|
||||
xgboost::CreateSparsePageDMatrix(kEntries);
|
||||
std::unique_ptr<xgboost::DMatrix> dmat = xgboost::CreateSparsePageDMatrix(kEntries);
|
||||
|
||||
int threads = 1000;
|
||||
|
||||
@@ -221,7 +224,8 @@ TEST(SparsePageDMatrix, ColAccessBatches) {
|
||||
// Create multiple sparse pages
|
||||
std::unique_ptr<xgboost::DMatrix> dmat{xgboost::CreateSparsePageDMatrix(kEntries)};
|
||||
ASSERT_EQ(dmat->Ctx()->Threads(), AllThreadsForTest());
|
||||
for (auto const &page : dmat->GetBatches<xgboost::CSCPage>()) {
|
||||
auto ctx = CreateEmptyGenericParam(Context::kCpuId);
|
||||
for (auto const &page : dmat->GetBatches<xgboost::CSCPage>(&ctx)) {
|
||||
ASSERT_EQ(dmat->Info().num_col_, page.Size());
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user