Use Booster context in DMatrix. (#8896)
- Pass context from booster to DMatrix. - Use context instead of integer for `n_threads`. - Check the consistency configuration for `max_bin`. - Test for all combinations of initialization options.
This commit is contained in:
@@ -1,15 +1,20 @@
|
||||
/**
|
||||
* Copyright 2019-2023 by XGBoost Contributors
|
||||
*/
|
||||
#include <xgboost/data.h> // for DMatrix
|
||||
|
||||
#include "../../../src/common/compressed_iterator.h"
|
||||
#include "../../../src/data/ellpack_page.cuh"
|
||||
#include "../../../src/data/sparse_page_dmatrix.h"
|
||||
#include "../filesystem.h" // dmlc::TemporaryDirectory
|
||||
#include "../../../src/tree/param.h" // TrainParam
|
||||
#include "../filesystem.h" // dmlc::TemporaryDirectory
|
||||
#include "../helpers.h"
|
||||
|
||||
namespace xgboost {
|
||||
|
||||
TEST(SparsePageDMatrix, EllpackPage) {
|
||||
Context ctx{MakeCUDACtx(0)};
|
||||
auto param = BatchParam{256, tree::TrainParam::DftSparseThreshold()};
|
||||
dmlc::TemporaryDirectory tempdir;
|
||||
const std::string tmp_file = tempdir.path + "/simple.libsvm";
|
||||
CreateSimpleTestData(tmp_file);
|
||||
@@ -17,7 +22,7 @@ TEST(SparsePageDMatrix, EllpackPage) {
|
||||
|
||||
// Loop over the batches and assert the data is as expected
|
||||
size_t n = 0;
|
||||
for (const auto& batch : dmat->GetBatches<EllpackPage>({0, 256})) {
|
||||
for (const auto& batch : dmat->GetBatches<EllpackPage>(&ctx, param)) {
|
||||
n += batch.Size();
|
||||
}
|
||||
EXPECT_EQ(n, dmat->Info().num_row_);
|
||||
@@ -37,6 +42,8 @@ TEST(SparsePageDMatrix, EllpackPage) {
|
||||
}
|
||||
|
||||
TEST(SparsePageDMatrix, MultipleEllpackPages) {
|
||||
Context ctx{MakeCUDACtx(0)};
|
||||
auto param = BatchParam{256, tree::TrainParam::DftSparseThreshold()};
|
||||
dmlc::TemporaryDirectory tmpdir;
|
||||
std::string filename = tmpdir.path + "/big.libsvm";
|
||||
size_t constexpr kPageSize = 64, kEntriesPerCol = 3;
|
||||
@@ -46,7 +53,7 @@ TEST(SparsePageDMatrix, MultipleEllpackPages) {
|
||||
// Loop over the batches and count the records
|
||||
int64_t batch_count = 0;
|
||||
int64_t row_count = 0;
|
||||
for (const auto& batch : dmat->GetBatches<EllpackPage>({0, 256})) {
|
||||
for (const auto& batch : dmat->GetBatches<EllpackPage>(&ctx, param)) {
|
||||
EXPECT_LT(batch.Size(), dmat->Info().num_row_);
|
||||
batch_count++;
|
||||
row_count += batch.Size();
|
||||
@@ -61,8 +68,11 @@ TEST(SparsePageDMatrix, MultipleEllpackPages) {
|
||||
}
|
||||
|
||||
TEST(SparsePageDMatrix, RetainEllpackPage) {
|
||||
Context ctx{MakeCUDACtx(0)};
|
||||
auto param = BatchParam{32, tree::TrainParam::DftSparseThreshold()};
|
||||
auto m = CreateSparsePageDMatrix(10000);
|
||||
auto batches = m->GetBatches<EllpackPage>({0, 32});
|
||||
|
||||
auto batches = m->GetBatches<EllpackPage>(&ctx, param);
|
||||
auto begin = batches.begin();
|
||||
auto end = batches.end();
|
||||
|
||||
@@ -87,7 +97,7 @@ TEST(SparsePageDMatrix, RetainEllpackPage) {
|
||||
}
|
||||
|
||||
// make sure it's const and the caller can not modify the content of page.
|
||||
for (auto& page : m->GetBatches<EllpackPage>({0, 32})) {
|
||||
for (auto& page : m->GetBatches<EllpackPage>(&ctx, param)) {
|
||||
static_assert(std::is_const<std::remove_reference_t<decltype(page)>>::value);
|
||||
}
|
||||
|
||||
@@ -98,6 +108,7 @@ TEST(SparsePageDMatrix, RetainEllpackPage) {
|
||||
}
|
||||
|
||||
TEST(SparsePageDMatrix, EllpackPageContent) {
|
||||
auto ctx = CreateEmptyGenericParam(0);
|
||||
constexpr size_t kRows = 6;
|
||||
constexpr size_t kCols = 2;
|
||||
constexpr size_t kPageSize = 1;
|
||||
@@ -110,8 +121,8 @@ TEST(SparsePageDMatrix, EllpackPageContent) {
|
||||
std::unique_ptr<DMatrix>
|
||||
dmat_ext(CreateSparsePageDMatrixWithRC(kRows, kCols, kPageSize, true, tmpdir));
|
||||
|
||||
BatchParam param{0, 2};
|
||||
auto impl = (*dmat->GetBatches<EllpackPage>(param).begin()).Impl();
|
||||
auto param = BatchParam{2, tree::TrainParam::DftSparseThreshold()};
|
||||
auto impl = (*dmat->GetBatches<EllpackPage>(&ctx, param).begin()).Impl();
|
||||
EXPECT_EQ(impl->base_rowid, 0);
|
||||
EXPECT_EQ(impl->n_rows, kRows);
|
||||
EXPECT_FALSE(impl->is_dense);
|
||||
@@ -120,7 +131,7 @@ TEST(SparsePageDMatrix, EllpackPageContent) {
|
||||
|
||||
std::unique_ptr<EllpackPageImpl> impl_ext;
|
||||
size_t offset = 0;
|
||||
for (auto& batch : dmat_ext->GetBatches<EllpackPage>(param)) {
|
||||
for (auto& batch : dmat_ext->GetBatches<EllpackPage>(&ctx, param)) {
|
||||
if (!impl_ext) {
|
||||
impl_ext.reset(new EllpackPageImpl(
|
||||
batch.Impl()->gidx_buffer.DeviceIdx(), batch.Impl()->Cuts(),
|
||||
@@ -170,8 +181,9 @@ TEST(SparsePageDMatrix, MultipleEllpackPageContent) {
|
||||
std::unique_ptr<DMatrix>
|
||||
dmat_ext(CreateSparsePageDMatrixWithRC(kRows, kCols, kPageSize, true, tmpdir));
|
||||
|
||||
BatchParam param{0, kMaxBins};
|
||||
auto impl = (*dmat->GetBatches<EllpackPage>(param).begin()).Impl();
|
||||
Context ctx{MakeCUDACtx(0)};
|
||||
auto param = BatchParam{kMaxBins, tree::TrainParam::DftSparseThreshold()};
|
||||
auto impl = (*dmat->GetBatches<EllpackPage>(&ctx, param).begin()).Impl();
|
||||
EXPECT_EQ(impl->base_rowid, 0);
|
||||
EXPECT_EQ(impl->n_rows, kRows);
|
||||
|
||||
@@ -180,7 +192,7 @@ TEST(SparsePageDMatrix, MultipleEllpackPageContent) {
|
||||
thrust::device_vector<bst_float> row_ext_d(kCols);
|
||||
std::vector<bst_float> row(kCols);
|
||||
std::vector<bst_float> row_ext(kCols);
|
||||
for (auto& page : dmat_ext->GetBatches<EllpackPage>(param)) {
|
||||
for (auto& page : dmat_ext->GetBatches<EllpackPage>(&ctx, param)) {
|
||||
auto impl_ext = page.Impl();
|
||||
EXPECT_EQ(impl_ext->base_rowid, current_row);
|
||||
|
||||
@@ -211,10 +223,11 @@ TEST(SparsePageDMatrix, EllpackPageMultipleLoops) {
|
||||
std::unique_ptr<DMatrix>
|
||||
dmat_ext(CreateSparsePageDMatrixWithRC(kRows, kCols, kPageSize, true, tmpdir));
|
||||
|
||||
BatchParam param{0, kMaxBins};
|
||||
Context ctx{MakeCUDACtx(0)};
|
||||
auto param = BatchParam{kMaxBins, tree::TrainParam::DftSparseThreshold()};
|
||||
|
||||
size_t current_row = 0;
|
||||
for (auto& page : dmat_ext->GetBatches<EllpackPage>(param)) {
|
||||
for (auto& page : dmat_ext->GetBatches<EllpackPage>(&ctx, param)) {
|
||||
auto impl_ext = page.Impl();
|
||||
EXPECT_EQ(impl_ext->base_rowid, current_row);
|
||||
current_row += impl_ext->n_rows;
|
||||
|
||||
Reference in New Issue
Block a user