[EM] Enable access to the number of batches. (#10691)

- Expose `NumBatches` in `DMatrix`.
- Small cleanup for removing legacy CUDA stream and ~force CUDA context initialization~.
- Purge old external memory data generation code.
This commit is contained in:
Jiaming Yuan 2024-08-17 02:59:45 +08:00 committed by GitHub
parent 033a666900
commit 8d7fe262d9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
26 changed files with 169 additions and 352 deletions

View File

@ -541,9 +541,12 @@ class DMatrix {
[[nodiscard]] bool PageExists() const;
/**
* @return Whether the data columns single column block.
* @return Whether the contains a single batch.
*
* The naming is legacy.
*/
[[nodiscard]] virtual bool SingleColBlock() const = 0;
[[nodiscard]] bool SingleColBlock() const { return this->NumBatches() == 1; }
[[nodiscard]] virtual std::int32_t NumBatches() const { return 1; }
virtual ~DMatrix();

View File

@ -486,24 +486,20 @@ class TypedDiscard : public thrust::discard_iterator<T> {
} // namespace detail
template <typename T>
using TypedDiscard =
std::conditional_t<HasThrustMinorVer<12>(), detail::TypedDiscardCTK114<T>,
detail::TypedDiscard<T>>;
using TypedDiscard = std::conditional_t<HasThrustMinorVer<12>(), detail::TypedDiscardCTK114<T>,
detail::TypedDiscard<T>>;
template <typename VectorT, typename T = typename VectorT::value_type,
typename IndexT = typename xgboost::common::Span<T>::index_type>
xgboost::common::Span<T> ToSpan(
VectorT &vec,
IndexT offset = 0,
IndexT size = std::numeric_limits<size_t>::max()) {
typename IndexT = typename xgboost::common::Span<T>::index_type>
xgboost::common::Span<T> ToSpan(VectorT &vec, IndexT offset = 0,
IndexT size = std::numeric_limits<size_t>::max()) {
size = size == std::numeric_limits<size_t>::max() ? vec.size() : size;
CHECK_LE(offset + size, vec.size());
return {vec.data().get() + offset, size};
return {thrust::raw_pointer_cast(vec.data()) + offset, size};
}
template <typename T>
xgboost::common::Span<T> ToSpan(thrust::device_vector<T>& vec,
size_t offset, size_t size) {
xgboost::common::Span<T> ToSpan(thrust::device_vector<T> &vec, size_t offset, size_t size) {
return ToSpan(vec, offset, size);
}
@ -874,13 +870,7 @@ inline void CUDAEvent::Record(CUDAStreamView stream) { // NOLINT
// Changing this has effect on prediction return, where we need to pass the pointer to
// third-party libraries like cuPy
inline CUDAStreamView DefaultStream() {
#ifdef CUDA_API_PER_THREAD_DEFAULT_STREAM
return CUDAStreamView{cudaStreamPerThread};
#else
return CUDAStreamView{cudaStreamLegacy};
#endif
}
inline CUDAStreamView DefaultStream() { return CUDAStreamView{cudaStreamPerThread}; }
class CUDAStream {
cudaStream_t stream_;

View File

@ -74,6 +74,8 @@ void ExtMemQuantileDMatrix::InitFromCPU(
cpu_impl::GetDataShape(ctx, proxy, *iter, missing, &ext_info);
ext_info.SetInfo(ctx, &this->info_);
this->n_batches_ = ext_info.n_batches;
/**
* Generate quantiles
*/

View File

@ -33,7 +33,7 @@ class ExtMemQuantileDMatrix : public QuantileDMatrix {
std::string cache, bst_bin_t max_bin, bool on_host);
~ExtMemQuantileDMatrix() override;
[[nodiscard]] bool SingleColBlock() const override { return false; }
[[nodiscard]] std::int32_t NumBatches() const override { return n_batches_; }
private:
void InitFromCPU(
@ -63,6 +63,7 @@ class ExtMemQuantileDMatrix : public QuantileDMatrix {
std::string cache_prefix_;
bool on_host_;
BatchParam batch_;
bst_idx_t n_batches_{0};
using EllpackDiskPtr = std::shared_ptr<ExtEllpackPageSource>;
using EllpackHostPtr = std::shared_ptr<ExtEllpackPageHostSource>;

View File

@ -57,8 +57,6 @@ class IterativeDMatrix : public QuantileDMatrix {
BatchSet<EllpackPage> GetEllpackBatches(Context const *ctx, const BatchParam &param) override;
BatchSet<ExtSparsePage> GetExtBatches(Context const *ctx, BatchParam const &param) override;
bool SingleColBlock() const override { return true; }
};
} // namespace data
} // namespace xgboost

View File

@ -94,7 +94,6 @@ class DMatrixProxy : public DMatrix {
MetaInfo const& Info() const override { return info_; }
Context const* Ctx() const override { return &ctx_; }
bool SingleColBlock() const override { return false; }
bool EllpackExists() const override { return false; }
bool GHistIndexExists() const override { return false; }
bool SparsePageExists() const override { return false; }

View File

@ -33,7 +33,6 @@ class SimpleDMatrix : public DMatrix {
const MetaInfo& Info() const override;
Context const* Ctx() const override { return &fmat_ctx_; }
bool SingleColBlock() const override { return true; }
DMatrix* Slice(common::Span<int32_t const> ridxs) override;
DMatrix* SliceCol(int num_slices, int slice_id) override;

View File

@ -90,8 +90,7 @@ class SparsePageDMatrix : public DMatrix {
[[nodiscard]] MetaInfo &Info() override;
[[nodiscard]] const MetaInfo &Info() const override;
[[nodiscard]] Context const *Ctx() const override { return &fmat_ctx_; }
// The only DMatrix implementation that returns false.
[[nodiscard]] bool SingleColBlock() const override { return false; }
[[nodiscard]] std::int32_t NumBatches() const override { return n_batches_; }
DMatrix *Slice(common::Span<std::int32_t const>) override {
LOG(FATAL) << "Slicing DMatrix is not supported for external memory.";
return nullptr;

View File

@ -3,10 +3,10 @@
*/
#include "sparse_page_source.h"
#include <filesystem> // for exists
#include <string> // for string
#include <cstdio> // for remove
#include <filesystem> // for exists
#include <numeric> // for partial_sum
#include <string> // for string
namespace xgboost::data {
void Cache::Commit() {
@ -27,4 +27,8 @@ void TryDeleteCacheFile(const std::string& file) {
<< "; you may want to remove it manually";
}
}
#if !defined(XGBOOST_USE_CUDA)
void InitNewThread::operator()() const { *GlobalConfigThreadLocalStore::Get() = config; }
#endif
} // namespace xgboost::data

View File

@ -18,4 +18,14 @@ void DevicePush(DMatrixProxy *proxy, float missing, SparsePage *page) {
cuda_impl::Dispatch(proxy,
[&](auto const &value) { CopyToSparsePage(value, device, missing, page); });
}
void InitNewThread::operator()() const {
*GlobalConfigThreadLocalStore::Get() = config;
// For CUDA 12.2, we need to force initialize the CUDA context by synchronizing the
// stream when creating a new thread in the thread pool. While for CUDA 11.8, this
// action might cause an insufficient driver version error for some reason. Lastly, it
// should work with CUDA 12.5 without any action being taken.
// dh::DefaultStream().Sync();
}
} // namespace xgboost::data

View File

@ -210,6 +210,12 @@ class DefaultFormatPolicy {
}
};
struct InitNewThread {
GlobalConfiguration config = *GlobalConfigThreadLocalStore::Get();
void operator()() const;
};
/**
* @brief Base class for all page sources. Handles fetching, writing, and iteration.
*
@ -330,10 +336,7 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S>, public FormatStreamPol
public:
SparsePageSourceImpl(float missing, int nthreads, bst_feature_t n_features, bst_idx_t n_batches,
std::shared_ptr<Cache> cache)
: workers_{std::max(2, std::min(nthreads, 16)),
[config = *GlobalConfigThreadLocalStore::Get()] {
*GlobalConfigThreadLocalStore::Get() = config;
}},
: workers_{std::max(2, std::min(nthreads, 16)), InitNewThread{}},
missing_{missing},
nthreads_{nthreads},
n_features_{n_features},

View File

@ -63,26 +63,27 @@ TEST(SparsePage, PushCSC) {
}
TEST(SparsePage, PushCSCAfterTranspose) {
size_t constexpr kPageSize = 1024, kEntriesPerCol = 3;
size_t constexpr kEntries = kPageSize * kEntriesPerCol * 2;
std::unique_ptr<DMatrix> dmat = CreateSparsePageDMatrix(kEntries);
bst_idx_t constexpr kRows = 1024, kCols = 21;
auto dmat =
RandomDataGenerator{kRows, kCols, 0.0f}.Batches(4).GenerateSparsePageDMatrix("temp", true);
const int ncols = dmat->Info().num_col_;
SparsePage page; // Consolidated sparse page
for (const auto &batch : dmat->GetBatches<xgboost::SparsePage>()) {
SparsePage page; // Consolidated sparse page
for (const auto& batch : dmat->GetBatches<xgboost::SparsePage>()) {
// Transpose each batch and push
SparsePage tmp = batch.GetTranspose(ncols, AllThreadsForTest());
page.PushCSC(tmp);
}
// Make sure that the final sparse page has the right number of entries
ASSERT_EQ(kEntries, page.data.Size());
ASSERT_EQ(kRows * kCols, page.data.Size());
page.SortRows(AllThreadsForTest());
auto v = page.GetView();
for (size_t i = 0; i < v.Size(); ++i) {
auto column = v[i];
for (size_t j = 1; j < column.size(); ++j) {
ASSERT_GE(column[j].fvalue, column[j-1].fvalue);
ASSERT_GE(column[j].fvalue, column[j - 1].fvalue);
}
}
}

View File

@ -140,13 +140,11 @@ struct ReadRowFunction {
TEST(EllpackPage, Copy) {
constexpr size_t kRows = 1024;
constexpr size_t kCols = 16;
constexpr size_t kPageSize = 1024;
// Create a DMatrix with multiple batches.
dmlc::TemporaryDirectory tmpdir;
std::unique_ptr<DMatrix>
dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, kPageSize, true, tmpdir));
Context ctx{MakeCUDACtx(0)};
auto dmat =
RandomDataGenerator{kRows, kCols, 0.0f}.Batches(4).GenerateSparsePageDMatrix("temp", true);
auto ctx = MakeCUDACtx(0);
auto param = BatchParam{256, tree::TrainParam::DftSparseThreshold()};
auto page = (*dmat->GetBatches<EllpackPage>(&ctx, param).begin()).Impl();
@ -187,14 +185,12 @@ TEST(EllpackPage, Copy) {
TEST(EllpackPage, Compact) {
constexpr size_t kRows = 16;
constexpr size_t kCols = 2;
constexpr size_t kPageSize = 1;
constexpr size_t kCompactedRows = 8;
// Create a DMatrix with multiple batches.
dmlc::TemporaryDirectory tmpdir;
std::unique_ptr<DMatrix> dmat(
CreateSparsePageDMatrixWithRC(kRows, kCols, kPageSize, true, tmpdir));
Context ctx{MakeCUDACtx(0)};
auto dmat =
RandomDataGenerator{kRows, kCols, 0.0f}.Batches(2).GenerateSparsePageDMatrix("temp", true);
auto ctx = MakeCUDACtx(0);
auto param = BatchParam{256, tree::TrainParam::DftSparseThreshold()};
auto page = (*dmat->GetBatches<EllpackPage>(&ctx, param).begin()).Impl();

View File

@ -214,15 +214,15 @@ TEST(SparsePageDMatrix, MetaInfo) {
}
TEST(SparsePageDMatrix, RowAccess) {
std::unique_ptr<xgboost::DMatrix> dmat = xgboost::CreateSparsePageDMatrix(24);
auto dmat = RandomDataGenerator{12, 6, 0.8f}.Batches(2).GenerateSparsePageDMatrix("temp", false);
// Test the data read into the first row
auto &batch = *dmat->GetBatches<xgboost::SparsePage>().begin();
auto page = batch.GetView();
auto first_row = page[0];
ASSERT_EQ(first_row.size(), 3ul);
EXPECT_EQ(first_row[2].index, 2u);
EXPECT_NEAR(first_row[2].fvalue, 0.986566, 1e-4);
ASSERT_EQ(first_row.size(), 1ul);
EXPECT_EQ(first_row[0].index, 5u);
EXPECT_NEAR(first_row[0].fvalue, 0.1805125, 1e-4);
}
TEST(SparsePageDMatrix, ColAccess) {
@ -268,11 +268,10 @@ TEST(SparsePageDMatrix, ColAccess) {
}
TEST(SparsePageDMatrix, ThreadSafetyException) {
size_t constexpr kEntriesPerCol = 3;
size_t constexpr kEntries = 64 * kEntriesPerCol * 2;
Context ctx;
std::unique_ptr<xgboost::DMatrix> dmat = xgboost::CreateSparsePageDMatrix(kEntries);
auto dmat =
RandomDataGenerator{4096, 12, 0.0f}.Batches(8).GenerateSparsePageDMatrix("temp", true);
int threads = 1000;
@ -304,10 +303,9 @@ TEST(SparsePageDMatrix, ThreadSafetyException) {
// Multi-batches access
TEST(SparsePageDMatrix, ColAccessBatches) {
size_t constexpr kPageSize = 1024, kEntriesPerCol = 3;
size_t constexpr kEntries = kPageSize * kEntriesPerCol * 2;
// Create multiple sparse pages
std::unique_ptr<xgboost::DMatrix> dmat{xgboost::CreateSparsePageDMatrix(kEntries)};
auto dmat =
RandomDataGenerator{1024, 32, 0.4f}.Batches(3).GenerateSparsePageDMatrix("temp", true);
ASSERT_EQ(dmat->Ctx()->Threads(), AllThreadsForTest());
Context ctx;
for (auto const &page : dmat->GetBatches<xgboost::CSCPage>(&ctx)) {

View File

@ -115,13 +115,10 @@ TEST(SparsePageDMatrix, EllpackSkipSparsePage) {
}
TEST(SparsePageDMatrix, MultipleEllpackPages) {
Context ctx{MakeCUDACtx(0)};
auto ctx = MakeCUDACtx(0);
auto param = BatchParam{256, tree::TrainParam::DftSparseThreshold()};
dmlc::TemporaryDirectory tmpdir;
std::string filename = tmpdir.path + "/big.libsvm";
size_t constexpr kPageSize = 64, kEntriesPerCol = 3;
size_t constexpr kEntries = kPageSize * kEntriesPerCol * 2;
std::unique_ptr<DMatrix> dmat = CreateSparsePageDMatrix(kEntries, filename);
auto dmat = RandomDataGenerator{1024, 2, 0.5f}.Batches(2).GenerateSparsePageDMatrix("temp", true);
// Loop over the batches and count the records
std::int64_t batch_count = 0;
@ -135,15 +132,13 @@ TEST(SparsePageDMatrix, MultipleEllpackPages) {
EXPECT_EQ(row_count, dmat->Info().num_row_);
auto path =
data::MakeId(filename,
dynamic_cast<data::SparsePageDMatrix *>(dmat.get())) +
".ellpack.page";
data::MakeId("tmep", dynamic_cast<data::SparsePageDMatrix*>(dmat.get())) + ".ellpack.page";
}
TEST(SparsePageDMatrix, RetainEllpackPage) {
Context ctx{MakeCUDACtx(0)};
auto ctx = MakeCUDACtx(0);
auto param = BatchParam{32, tree::TrainParam::DftSparseThreshold()};
auto m = CreateSparsePageDMatrix(10000);
auto m = RandomDataGenerator{2048, 4, 0.0f}.Batches(8).GenerateSparsePageDMatrix("temp", true);
auto batches = m->GetBatches<EllpackPage>(&ctx, param);
auto begin = batches.begin();
@ -278,20 +273,19 @@ struct ReadRowFunction {
};
TEST(SparsePageDMatrix, MultipleEllpackPageContent) {
constexpr size_t kRows = 6;
constexpr size_t kRows = 16;
constexpr size_t kCols = 2;
constexpr int kMaxBins = 256;
constexpr size_t kPageSize = 1;
// Create an in-memory DMatrix.
std::unique_ptr<DMatrix> dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, 0, true));
auto dmat =
RandomDataGenerator{kRows, kCols, 0.0f}.Batches(1).GenerateSparsePageDMatrix("temp", true);
// Create a DMatrix with multiple batches.
dmlc::TemporaryDirectory tmpdir;
std::unique_ptr<DMatrix>
dmat_ext(CreateSparsePageDMatrixWithRC(kRows, kCols, kPageSize, true, tmpdir));
auto dmat_ext =
RandomDataGenerator{kRows, kCols, 0.0f}.Batches(2).GenerateSparsePageDMatrix("temp", true);
Context ctx{MakeCUDACtx(0)};
auto ctx = MakeCUDACtx(0);
auto param = BatchParam{kMaxBins, tree::TrainParam::DftSparseThreshold()};
auto impl = (*dmat->GetBatches<EllpackPage>(&ctx, param).begin()).Impl();
EXPECT_EQ(impl->base_rowid, 0);
@ -325,17 +319,16 @@ TEST(SparsePageDMatrix, EllpackPageMultipleLoops) {
constexpr size_t kRows = 1024;
constexpr size_t kCols = 16;
constexpr int kMaxBins = 256;
constexpr size_t kPageSize = 4096;
// Create an in-memory DMatrix.
std::unique_ptr<DMatrix> dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, 0, true));
auto dmat =
RandomDataGenerator{kRows, kCols, 0.0f}.Batches(1).GenerateSparsePageDMatrix("temp", true);
// Create a DMatrix with multiple batches.
dmlc::TemporaryDirectory tmpdir;
std::unique_ptr<DMatrix>
dmat_ext(CreateSparsePageDMatrixWithRC(kRows, kCols, kPageSize, true, tmpdir));
auto dmat_ext =
RandomDataGenerator{kRows, kCols, 0.0f}.Batches(8).GenerateSparsePageDMatrix("temp", true);
Context ctx{MakeCUDACtx(0)};
auto ctx = MakeCUDACtx(0);
auto param = BatchParam{kMaxBins, tree::TrainParam::DftSparseThreshold()};
size_t current_row = 0;

View File

@ -715,7 +715,7 @@ TEST(GBTree, InplacePredictionError) {
p_fmat = rng.GenerateQuantileDMatrix(true);
} else {
#if defined(XGBOOST_USE_CUDA)
p_fmat = rng.GenerateDeviceDMatrix(true);
p_fmat = rng.Device(ctx->Device()).GenerateQuantileDMatrix(true);
#else
CHECK(p_fmat);
#endif // defined(XGBOOST_USE_CUDA)

View File

@ -13,7 +13,6 @@
#include <algorithm>
#include <limits> // for numeric_limits
#include <random>
#include "../../src/collective/communicator-inl.h" // for GetRank
#include "../../src/data/adapter.h"
@ -21,8 +20,6 @@
#include "../../src/data/simple_dmatrix.h"
#include "../../src/data/sparse_page_dmatrix.h"
#include "../../src/gbm/gbtree_model.h"
#include "../../src/tree/param.h" // for TrainParam
#include "filesystem.h" // dmlc::TemporaryDirectory
#include "xgboost/c_api.h"
#include "xgboost/predictor.h"
@ -456,6 +453,7 @@ void RandomDataGenerator::GenerateCSR(
}
EXPECT_EQ(batch_count, n_batches_);
EXPECT_EQ(dmat->NumBatches(), n_batches_);
EXPECT_EQ(row_count, dmat->Info().num_row_);
if (with_label) {
@ -503,13 +501,24 @@ void RandomDataGenerator::GenerateCSR(
}
std::shared_ptr<DMatrix> RandomDataGenerator::GenerateQuantileDMatrix(bool with_label) {
NumpyArrayIterForTest iter{this->sparsity_, this->rows_, this->cols_, 1};
auto m = std::make_shared<data::IterativeDMatrix>(
&iter, iter.Proxy(), nullptr, Reset, Next, std::numeric_limits<float>::quiet_NaN(), 0, bins_);
if (with_label) {
this->GenerateLabels(m);
std::shared_ptr<data::IterativeDMatrix> p_fmat;
if (this->device_.IsCPU()) {
NumpyArrayIterForTest iter{this->sparsity_, this->rows_, this->cols_, 1};
p_fmat =
std::make_shared<data::IterativeDMatrix>(&iter, iter.Proxy(), nullptr, Reset, Next,
std::numeric_limits<float>::quiet_NaN(), 0, bins_);
} else {
CudaArrayIterForTest iter{this->sparsity_, this->rows_, this->cols_, 1};
p_fmat =
std::make_shared<data::IterativeDMatrix>(&iter, iter.Proxy(), nullptr, Reset, Next,
std::numeric_limits<float>::quiet_NaN(), 0, bins_);
}
return m;
if (with_label) {
this->GenerateLabels(p_fmat);
}
return p_fmat;
}
#if !defined(XGBOOST_USE_CUDA)
@ -551,125 +560,6 @@ std::shared_ptr<DMatrix> GetDMatrixFromData(const std::vector<float>& x, std::si
return p_fmat;
}
std::unique_ptr<DMatrix> CreateSparsePageDMatrix(bst_idx_t n_samples, bst_feature_t n_features,
size_t n_batches, std::string prefix) {
CHECK_GE(n_samples, n_batches);
NumpyArrayIterForTest iter(0, n_samples, n_features, n_batches);
std::unique_ptr<DMatrix> dmat{DMatrix::Create(
static_cast<DataIterHandle>(&iter), iter.Proxy(), Reset, Next,
std::numeric_limits<float>::quiet_NaN(), omp_get_max_threads(), prefix, false)};
auto row_page_path =
data::MakeId(prefix, dynamic_cast<data::SparsePageDMatrix*>(dmat.get())) + ".row.page";
EXPECT_TRUE(FileExists(row_page_path)) << row_page_path;
// Loop over the batches and count the number of pages
int64_t batch_count = 0;
int64_t row_count = 0;
for (const auto& batch : dmat->GetBatches<xgboost::SparsePage>()) {
batch_count++;
row_count += batch.Size();
}
EXPECT_GE(batch_count, n_batches);
EXPECT_EQ(row_count, dmat->Info().num_row_);
return dmat;
}
std::unique_ptr<DMatrix> CreateSparsePageDMatrix(size_t n_entries,
std::string prefix) {
size_t n_columns = 3;
size_t n_rows = n_entries / n_columns;
NumpyArrayIterForTest iter(0, n_rows, n_columns, 2);
std::unique_ptr<DMatrix> dmat{
DMatrix::Create(static_cast<DataIterHandle>(&iter), iter.Proxy(), Reset, Next,
std::numeric_limits<float>::quiet_NaN(), 0, prefix, false)};
auto row_page_path =
data::MakeId(prefix,
dynamic_cast<data::SparsePageDMatrix *>(dmat.get())) +
".row.page";
EXPECT_TRUE(FileExists(row_page_path)) << row_page_path;
// Loop over the batches and count the records
int64_t batch_count = 0;
int64_t row_count = 0;
for (const auto &batch : dmat->GetBatches<xgboost::SparsePage>()) {
batch_count++;
row_count += batch.Size();
}
EXPECT_GE(batch_count, 2);
EXPECT_EQ(row_count, dmat->Info().num_row_);
return dmat;
}
std::unique_ptr<DMatrix> CreateSparsePageDMatrixWithRC(size_t n_rows, size_t n_cols,
size_t page_size, bool deterministic,
const dmlc::TemporaryDirectory& tempdir) {
if (!n_rows || !n_cols) {
return nullptr;
}
// Create the svm file in a temp dir
const std::string tmp_file = tempdir.path + "/big.libsvm";
std::ofstream fo(tmp_file.c_str());
size_t cols_per_row = ((std::max(n_rows, n_cols) - 1) / std::min(n_rows, n_cols)) + 1;
int64_t rem_cols = n_cols;
size_t col_idx = 0;
// Random feature id generator
std::random_device rdev;
std::unique_ptr<std::mt19937> gen;
if (deterministic) {
// Seed it with a constant value for this configuration - without getting too fancy
// like ordered pairing functions and its likes to make it truely unique
gen.reset(new std::mt19937(n_rows * n_cols));
} else {
gen.reset(new std::mt19937(rdev()));
}
std::uniform_int_distribution<size_t> label(0, 1);
std::uniform_int_distribution<size_t> dis(1, n_cols);
for (size_t i = 0; i < n_rows; ++i) {
// Make sure that all cols are slotted in the first few rows; randomly distribute the
// rest
std::stringstream row_data;
size_t j = 0;
if (rem_cols > 0) {
for (; j < std::min(static_cast<size_t>(rem_cols), cols_per_row); ++j) {
row_data << label(*gen) << " " << (col_idx + j) << ":"
<< (col_idx + j + 1) * 10 * i;
}
rem_cols -= cols_per_row;
} else {
// Take some random number of colums in [1, n_cols] and slot them here
std::vector<size_t> random_columns;
size_t ncols = dis(*gen);
for (; j < ncols; ++j) {
size_t fid = (col_idx + j) % n_cols;
random_columns.push_back(fid);
}
std::sort(random_columns.begin(), random_columns.end());
for (auto fid : random_columns) {
row_data << label(*gen) << " " << fid << ":" << (fid + 1) * 10 * i;
}
}
col_idx += j;
fo << row_data.str() << "\n";
}
fo.close();
std::string uri = tmp_file + "?format=libsvm";
if (page_size > 0) {
uri += "#" + tmp_file + ".cache";
}
std::unique_ptr<DMatrix> dmat(DMatrix::Load(uri));
return dmat;
}
std::unique_ptr<GradientBooster> CreateTrainedGBM(std::string name, Args kwargs, size_t kRows,
size_t kCols,
LearnerModelParam const* learner_model_param,

View File

@ -3,12 +3,9 @@
*/
#include <xgboost/c_api.h>
#include "../../src/data/device_adapter.cuh"
#include "../../src/data/iterative_dmatrix.h"
#include "helpers.h"
namespace xgboost {
CudaArrayIterForTest::CudaArrayIterForTest(float sparsity, size_t rows,
size_t cols, size_t batches)
: ArrayIterForTest{sparsity, rows, cols, batches} {
@ -26,14 +23,4 @@ int CudaArrayIterForTest::Next() {
iter_++;
return 1;
}
std::shared_ptr<DMatrix> RandomDataGenerator::GenerateDeviceDMatrix(bool with_label) {
CudaArrayIterForTest iter{this->sparsity_, this->rows_, this->cols_, 1};
auto m = std::make_shared<data::IterativeDMatrix>(
&iter, iter.Proxy(), nullptr, Reset, Next, std::numeric_limits<float>::quiet_NaN(), 0, bins_);
if (with_label) {
this->GenerateLabels(m);
}
return m;
}
} // namespace xgboost

View File

@ -324,9 +324,6 @@ class RandomDataGenerator {
[[nodiscard]] std::shared_ptr<DMatrix> GenerateExtMemQuantileDMatrix(std::string prefix,
bool with_label) const;
#if defined(XGBOOST_USE_CUDA)
std::shared_ptr<DMatrix> GenerateDeviceDMatrix(bool with_label);
#endif
std::shared_ptr<DMatrix> GenerateQuantileDMatrix(bool with_label);
};
@ -350,45 +347,6 @@ inline std::vector<float> GenerateRandomCategoricalSingleColumn(int n, size_t nu
std::shared_ptr<DMatrix> GetDMatrixFromData(const std::vector<float>& x, std::size_t num_rows,
bst_feature_t num_columns);
/**
* \brief Create Sparse Page using data iterator.
*
* \param n_samples Total number of rows for all batches combined.
* \param n_features Number of features
* \param n_batches Number of batches
* \param prefix Cache prefix, can be used for specifying file path.
*
* \return A Sparse DMatrix with n_batches.
*/
std::unique_ptr<DMatrix> CreateSparsePageDMatrix(bst_idx_t n_samples, bst_feature_t n_features,
size_t n_batches, std::string prefix = "cache");
/**
* Deprecated, stop using it
*/
std::unique_ptr<DMatrix> CreateSparsePageDMatrix(size_t n_entries, std::string prefix = "cache");
/**
* Deprecated, stop using it
*
* \brief Creates dmatrix with some records, each record containing random number of
* features in [1, n_cols]
*
* \param n_rows Number of records to create.
* \param n_cols Max number of features within that record.
* \param page_size Sparse page size for the pages within the dmatrix. If page size is 0
* then the entire dmatrix is resident in memory; else, multiple sparse pages
* of page size are created and backed to disk, which would have to be
* streamed in at point of use.
* \param deterministic The content inside the dmatrix is constant for this configuration, if true;
* else, the content changes every time this method is invoked
*
* \return The new dmatrix.
*/
std::unique_ptr<DMatrix> CreateSparsePageDMatrixWithRC(
size_t n_rows, size_t n_cols, size_t page_size, bool deterministic,
const dmlc::TemporaryDirectory& tempdir = dmlc::TemporaryDirectory());
std::unique_ptr<GradientBooster> CreateTrainedGBM(std::string name, Args kwargs, size_t kRows,
size_t kCols,
LearnerModelParam const* learner_model_param,

View File

@ -36,9 +36,10 @@ TEST(SyclPredictor, ExternalMemory) {
Context ctx;
ctx.UpdateAllowUnknown(Args{{"device", "sycl"}});
size_t constexpr kPageSize = 64, kEntriesPerCol = 3;
size_t constexpr kEntries = kPageSize * kEntriesPerCol * 2;
std::unique_ptr<DMatrix> dmat = CreateSparsePageDMatrix(kEntries);
bst_idx_t constexpr kRows{64};
bst_feature_t constexpr kCols{12};
auto dmat =
RandomDataGenerator{kRows, kCols, 0.5f}.Batches(3).GenerateSparsePageDMatrix("temp", true);
TestBasic(dmat.get(), &ctx);
}

View File

@ -10,12 +10,10 @@
#include "../../../src/gbm/gbtree.h"
#include "../../../src/gbm/gbtree_model.h"
#include "../collective/test_worker.h" // for TestDistributedGlobal
#include "../filesystem.h" // dmlc::TemporaryDirectory
#include "../helpers.h"
#include "test_predictor.h"
namespace xgboost {
TEST(CpuPredictor, Basic) {
Context ctx;
size_t constexpr kRows = 5;
@ -56,9 +54,10 @@ TEST(CpuPredictor, IterationRangeColmnSplit) {
TEST(CpuPredictor, ExternalMemory) {
Context ctx;
size_t constexpr kPageSize = 64, kEntriesPerCol = 3;
size_t constexpr kEntries = kPageSize * kEntriesPerCol * 2;
std::unique_ptr<DMatrix> dmat = CreateSparsePageDMatrix(kEntries);
bst_idx_t constexpr kRows{64};
bst_feature_t constexpr kCols{12};
auto dmat =
RandomDataGenerator{kRows, kCols, 0.5f}.Batches(3).GenerateSparsePageDMatrix("temp", true);
TestBasic(dmat.get(), &ctx);
}

View File

@ -123,8 +123,8 @@ TEST(GPUPredictor, EllpackBasic) {
size_t rows = bins * 16;
auto p_m = RandomDataGenerator{rows, kCols, 0.0}
.Bins(bins)
.Device(DeviceOrd::CUDA(0))
.GenerateDeviceDMatrix(false);
.Device(ctx.Device())
.GenerateQuantileDMatrix(false);
ASSERT_FALSE(p_m->PageExists<SparsePage>());
TestPredictionFromGradientIndex<EllpackPage>(&ctx, rows, kCols, p_m);
TestPredictionFromGradientIndex<EllpackPage>(&ctx, bins, kCols, p_m);
@ -137,7 +137,7 @@ TEST(GPUPredictor, EllpackTraining) {
auto p_ellpack = RandomDataGenerator{kRows, kCols, 0.0}
.Bins(kBins)
.Device(ctx.Device())
.GenerateDeviceDMatrix(false);
.GenerateQuantileDMatrix(false);
HostDeviceVector<float> storage(kRows * kCols);
auto columnar =
RandomDataGenerator{kRows, kCols, 0.0}.Device(ctx.Device()).GenerateArrayInterface(&storage);

View File

@ -117,24 +117,15 @@ TEST(Learner, CheckGroup) {
EXPECT_ANY_THROW(learner->UpdateOneIter(0, p_mat));
}
TEST(Learner, SLOW_CheckMultiBatch) { // NOLINT
// Create sufficiently large data to make two row pages
dmlc::TemporaryDirectory tempdir;
const std::string tmp_file = tempdir.path + "/big.libsvm";
CreateBigTestData(tmp_file, 50000);
std::shared_ptr<DMatrix> dmat(
xgboost::DMatrix::Load(tmp_file + "?format=libsvm" + "#" + tmp_file + ".cache"));
EXPECT_FALSE(dmat->SingleColBlock());
size_t num_row = dmat->Info().num_row_;
std::vector<bst_float> labels(num_row);
for (size_t i = 0; i < num_row; ++i) {
labels[i] = i % 2;
}
dmat->SetInfo("label", Make1dInterfaceTest(labels.data(), num_row));
std::vector<std::shared_ptr<DMatrix>> mat{dmat};
TEST(Learner, CheckMultiBatch) {
auto p_fmat =
RandomDataGenerator{512, 128, 0.8}.Batches(4).GenerateSparsePageDMatrix("temp", true);
ASSERT_FALSE(p_fmat->SingleColBlock());
std::vector<std::shared_ptr<DMatrix>> mat{p_fmat};
auto learner = std::unique_ptr<Learner>(Learner::Create(mat));
learner->SetParams(Args{{"objective", "binary:logistic"}});
learner->UpdateOneIter(0, dmat);
learner->UpdateOneIter(0, p_fmat);
}
TEST(Learner, Configuration) {

View File

@ -7,22 +7,18 @@
#include "../../../../src/tree/gpu_hist/gradient_based_sampler.cuh"
#include "../../../../src/tree/param.h"
#include "../../../../src/tree/param.h" // TrainParam
#include "../../filesystem.h" // dmlc::TemporaryDirectory
#include "../../helpers.h"
namespace xgboost::tree {
void VerifySampling(size_t page_size,
float subsample,
int sampling_method,
bool fixed_size_sampling = true,
bool check_sum = true) {
void VerifySampling(size_t page_size, float subsample, int sampling_method,
bool fixed_size_sampling = true, bool check_sum = true) {
constexpr size_t kRows = 4096;
constexpr size_t kCols = 1;
size_t sample_rows = kRows * subsample;
bst_idx_t sample_rows = kRows * subsample;
bst_idx_t n_batches = fixed_size_sampling ? 1 : 4;
dmlc::TemporaryDirectory tmpdir;
std::unique_ptr<DMatrix> dmat(CreateSparsePageDMatrix(
kRows, kCols, kRows / (page_size == 0 ? kRows : page_size), tmpdir.path + "/cache"));
auto dmat = RandomDataGenerator{kRows, kCols, 0.0f}.Batches(n_batches).GenerateSparsePageDMatrix(
"temp", true);
auto gpair = GenerateRandomGradients(kRows);
GradientPair sum_gpair{};
for (const auto& gp : gpair.ConstHostVector()) {
@ -78,14 +74,12 @@ TEST(GradientBasedSampler, NoSamplingExternalMemory) {
constexpr size_t kRows = 2048;
constexpr size_t kCols = 1;
constexpr float kSubsample = 1.0f;
constexpr size_t kPageSize = 1024;
// Create a DMatrix with multiple batches.
dmlc::TemporaryDirectory tmpdir;
std::unique_ptr<DMatrix> dmat(
CreateSparsePageDMatrix(kRows, kCols, kRows / kPageSize, tmpdir.path + "/cache"));
auto dmat =
RandomDataGenerator{kRows, kCols, 0.0f}.Batches(4).GenerateSparsePageDMatrix("temp", true);
auto gpair = GenerateRandomGradients(kRows);
Context ctx{MakeCUDACtx(0)};
auto ctx = MakeCUDACtx(0);
gpair.SetDevice(ctx.Device());
auto param = BatchParam{256, tree::TrainParam::DftSparseThreshold()};

View File

@ -406,7 +406,8 @@ namespace {
void TestHistogramExternalMemory(Context const *ctx, BatchParam batch_param, bool is_approx,
bool force_read_by_column) {
size_t constexpr kEntries = 1 << 16;
auto m = CreateSparsePageDMatrix(kEntries, "cache");
auto m =
RandomDataGenerator{kEntries / 8, 8, 0.0f}.Batches(4).GenerateSparsePageDMatrix("temp", true);
std::vector<float> hess(m->Info().num_row_, 1.0);
if (is_approx) {

View File

@ -17,12 +17,11 @@
#include "../../../src/common/random.h" // for GlobalRandom
#include "../../../src/tree/param.h" // for TrainParam
#include "../collective/test_worker.h" // for BaseMGPUTest
#include "../filesystem.h" // dmlc::TemporaryDirectory
#include "../helpers.h"
namespace xgboost::tree {
namespace {
void UpdateTree(Context const* ctx, linalg::Matrix<GradientPair>* gpair, DMatrix* dmat, bool is_ext,
void UpdateTree(Context const* ctx, linalg::Matrix<GradientPair>* gpair, DMatrix* dmat,
RegTree* tree, HostDeviceVector<bst_float>* preds, float subsample,
const std::string& sampling_method, bst_bin_t max_bin) {
Args args{
@ -45,7 +44,7 @@ void UpdateTree(Context const* ctx, linalg::Matrix<GradientPair>* gpair, DMatrix
hist_maker->Update(&param, gpair, dmat, common::Span<HostDeviceVector<bst_node_t>>{position},
{tree});
auto cache = linalg::MakeTensorView(ctx, preds->DeviceSpan(), preds->Size(), 1);
if (subsample < 1.0 && is_ext) {
if (subsample < 1.0 && !dmat->SingleColBlock()) {
ASSERT_FALSE(hist_maker->UpdatePredictionCache(dmat, cache));
} else {
ASSERT_TRUE(hist_maker->UpdatePredictionCache(dmat, cache));
@ -58,22 +57,23 @@ TEST(GpuHist, UniformSampling) {
constexpr size_t kCols = 2;
constexpr float kSubsample = 0.9999;
common::GlobalRandom().seed(1994);
auto ctx = MakeCUDACtx(0);
// Create an in-memory DMatrix.
std::unique_ptr<DMatrix> dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, 0, true));
auto p_fmat = RandomDataGenerator{kRows, kCols, 0.0f}.GenerateDMatrix(true);
ASSERT_TRUE(p_fmat->SingleColBlock());
linalg::Matrix<GradientPair> gpair({kRows}, Context{}.MakeCUDA().Device());
linalg::Matrix<GradientPair> gpair({kRows}, ctx.Device());
gpair.Data()->Copy(GenerateRandomGradients(kRows));
// Build a tree using the in-memory DMatrix.
RegTree tree;
HostDeviceVector<bst_float> preds(kRows, 0.0, DeviceOrd::CUDA(0));
Context ctx(MakeCUDACtx(0));
UpdateTree(&ctx, &gpair, dmat.get(), false, &tree, &preds, 1.0, "uniform", kRows);
HostDeviceVector<bst_float> preds(kRows, 0.0, ctx.Device());
UpdateTree(&ctx, &gpair, p_fmat.get(), &tree, &preds, 1.0, "uniform", kRows);
// Build another tree using sampling.
RegTree tree_sampling;
HostDeviceVector<bst_float> preds_sampling(kRows, 0.0, DeviceOrd::CUDA(0));
UpdateTree(&ctx, &gpair, dmat.get(), false, &tree_sampling, &preds_sampling, kSubsample, "uniform",
HostDeviceVector<bst_float> preds_sampling(kRows, 0.0, ctx.Device());
UpdateTree(&ctx, &gpair, p_fmat.get(), &tree_sampling, &preds_sampling, kSubsample, "uniform",
kRows);
// Make sure the predictions are the same.
@ -89,23 +89,23 @@ TEST(GpuHist, GradientBasedSampling) {
constexpr size_t kCols = 2;
constexpr float kSubsample = 0.9999;
common::GlobalRandom().seed(1994);
auto ctx = MakeCUDACtx(0);
// Create an in-memory DMatrix.
std::unique_ptr<DMatrix> dmat(CreateSparsePageDMatrixWithRC(kRows, kCols, 0, true));
auto p_fmat = RandomDataGenerator{kRows, kCols, 0.0f}.GenerateDMatrix(true);
linalg::Matrix<GradientPair> gpair({kRows}, MakeCUDACtx(0).Device());
linalg::Matrix<GradientPair> gpair({kRows}, ctx.Device());
gpair.Data()->Copy(GenerateRandomGradients(kRows));
// Build a tree using the in-memory DMatrix.
RegTree tree;
HostDeviceVector<bst_float> preds(kRows, 0.0, DeviceOrd::CUDA(0));
Context ctx(MakeCUDACtx(0));
UpdateTree(&ctx, &gpair, dmat.get(), false, &tree, &preds, 1.0, "uniform", kRows);
HostDeviceVector<bst_float> preds(kRows, 0.0, ctx.Device());
UpdateTree(&ctx, &gpair, p_fmat.get(), &tree, &preds, 1.0, "uniform", kRows);
// Build another tree using sampling.
RegTree tree_sampling;
HostDeviceVector<bst_float> preds_sampling(kRows, 0.0, DeviceOrd::CUDA(0));
UpdateTree(&ctx, &gpair, dmat.get(), false, &tree_sampling, &preds_sampling, kSubsample,
HostDeviceVector<bst_float> preds_sampling(kRows, 0.0, ctx.Device());
UpdateTree(&ctx, &gpair, p_fmat.get(), &tree_sampling, &preds_sampling, kSubsample,
"gradient_based", kRows);
// Make sure the predictions are the same.
@ -119,29 +119,29 @@ TEST(GpuHist, GradientBasedSampling) {
TEST(GpuHist, ExternalMemory) {
constexpr size_t kRows = 4096;
constexpr size_t kCols = 2;
constexpr size_t kPageSize = 1024;
dmlc::TemporaryDirectory tmpdir;
// Create a DMatrix with multiple batches.
std::unique_ptr<DMatrix> dmat_ext(
CreateSparsePageDMatrix(kRows, kCols, kRows / kPageSize, tmpdir.path + "/cache"));
auto p_fmat_ext =
RandomDataGenerator{kRows, kCols, 0.0f}.Batches(4).GenerateSparsePageDMatrix("temp", true);
ASSERT_FALSE(p_fmat_ext->SingleColBlock());
// Create a single batch DMatrix.
std::unique_ptr<DMatrix> dmat(CreateSparsePageDMatrix(kRows, kCols, 1, tmpdir.path + "/cache"));
auto p_fmat =
RandomDataGenerator{kRows, kCols, 0.0f}.Batches(1).GenerateSparsePageDMatrix("temp", true);
ASSERT_TRUE(p_fmat->SingleColBlock());
Context ctx(MakeCUDACtx(0));
auto ctx = MakeCUDACtx(0);
linalg::Matrix<GradientPair> gpair({kRows}, ctx.Device());
gpair.Data()->Copy(GenerateRandomGradients(kRows));
// Build a tree using the in-memory DMatrix.
RegTree tree;
HostDeviceVector<bst_float> preds(kRows, 0.0, DeviceOrd::CUDA(0));
UpdateTree(&ctx, &gpair, dmat.get(), false, &tree, &preds, 1.0, "uniform", kRows);
HostDeviceVector<bst_float> preds(kRows, 0.0, ctx.Device());
UpdateTree(&ctx, &gpair, p_fmat.get(), &tree, &preds, 1.0, "uniform", kRows);
// Build another tree using multiple ELLPACK pages.
RegTree tree_ext;
HostDeviceVector<bst_float> preds_ext(kRows, 0.0, DeviceOrd::CUDA(0));
UpdateTree(&ctx, &gpair, dmat_ext.get(), true, &tree_ext, &preds_ext, 1.0, "uniform", kRows);
HostDeviceVector<bst_float> preds_ext(kRows, 0.0, ctx.Device());
UpdateTree(&ctx, &gpair, p_fmat_ext.get(), &tree_ext, &preds_ext, 1.0, "uniform", kRows);
// Make sure the predictions are the same.
auto preds_h = preds.ConstHostVector();
@ -157,20 +157,21 @@ TEST(GpuHist, ExternalMemoryWithSampling) {
const std::string kSamplingMethod = "gradient_based";
common::GlobalRandom().seed(0);
dmlc::TemporaryDirectory tmpdir;
Context ctx(MakeCUDACtx(0));
auto ctx = MakeCUDACtx(0);
// Create a single batch DMatrix.
auto p_fmat = RandomDataGenerator{kRows, kCols, 0.0f}
.Device(ctx.Device())
.Batches(1)
.GenerateSparsePageDMatrix("temp", true);
ASSERT_TRUE(p_fmat->SingleColBlock());
// Create a DMatrix with multiple batches.
auto p_fmat_ext = RandomDataGenerator{kRows, kCols, 0.0f}
.Device(ctx.Device())
.Batches(4)
.GenerateSparsePageDMatrix("temp", true);
ASSERT_FALSE(p_fmat_ext->SingleColBlock());
linalg::Matrix<GradientPair> gpair({kRows}, ctx.Device());
gpair.Data()->Copy(GenerateRandomGradients(kRows));
@ -179,26 +180,25 @@ TEST(GpuHist, ExternalMemoryWithSampling) {
auto rng = common::GlobalRandom();
RegTree tree;
HostDeviceVector<bst_float> preds(kRows, 0.0, DeviceOrd::CUDA(0));
UpdateTree(&ctx, &gpair, p_fmat.get(), true, &tree, &preds, kSubsample, kSamplingMethod, kRows);
HostDeviceVector<bst_float> preds(kRows, 0.0, ctx.Device());
UpdateTree(&ctx, &gpair, p_fmat.get(), &tree, &preds, kSubsample, kSamplingMethod, kRows);
// Build another tree using multiple ELLPACK pages.
common::GlobalRandom() = rng;
RegTree tree_ext;
HostDeviceVector<bst_float> preds_ext(kRows, 0.0, DeviceOrd::CUDA(0));
UpdateTree(&ctx, &gpair, p_fmat_ext.get(), true, &tree_ext, &preds_ext, kSubsample,
kSamplingMethod, kRows);
HostDeviceVector<bst_float> preds_ext(kRows, 0.0, ctx.Device());
UpdateTree(&ctx, &gpair, p_fmat_ext.get(), &tree_ext, &preds_ext, kSubsample, kSamplingMethod,
kRows);
// Make sure the predictions are the same.
auto preds_h = preds.ConstHostVector();
auto preds_ext_h = preds_ext.ConstHostVector();
for (size_t i = 0; i < kRows; i++) {
ASSERT_NEAR(preds_h[i], preds_ext_h[i], 1e-3);
}
Json jtree{Object{}};
Json jtree_ext{Object{}};
tree.SaveModel(&jtree);
tree_ext.SaveModel(&jtree_ext);
ASSERT_EQ(jtree, jtree_ext);
}
TEST(GpuHist, ConfigIO) {
Context ctx(MakeCUDACtx(0));
auto ctx = MakeCUDACtx(0);
ObjInfo task{ObjInfo::kRegression};
std::unique_ptr<TreeUpdater> updater{TreeUpdater::Create("grow_gpu_hist", &ctx, &task)};
updater->Configure(Args{});