[EM] Enable access to the number of batches. (#10691)

- Expose `NumBatches` in `DMatrix`.
- Small cleanup for removing legacy CUDA stream and ~force CUDA context initialization~.
- Purge old external memory data generation code.
This commit is contained in:
Jiaming Yuan
2024-08-17 02:59:45 +08:00
committed by GitHub
parent 033a666900
commit 8d7fe262d9
26 changed files with 169 additions and 352 deletions

View File

@@ -10,12 +10,10 @@
#include "../../../src/gbm/gbtree.h"
#include "../../../src/gbm/gbtree_model.h"
#include "../collective/test_worker.h" // for TestDistributedGlobal
#include "../filesystem.h" // dmlc::TemporaryDirectory
#include "../helpers.h"
#include "test_predictor.h"
namespace xgboost {
TEST(CpuPredictor, Basic) {
Context ctx;
size_t constexpr kRows = 5;
@@ -56,9 +54,10 @@ TEST(CpuPredictor, IterationRangeColmnSplit) {
TEST(CpuPredictor, ExternalMemory) {
Context ctx;
size_t constexpr kPageSize = 64, kEntriesPerCol = 3;
size_t constexpr kEntries = kPageSize * kEntriesPerCol * 2;
std::unique_ptr<DMatrix> dmat = CreateSparsePageDMatrix(kEntries);
bst_idx_t constexpr kRows{64};
bst_feature_t constexpr kCols{12};
auto dmat =
RandomDataGenerator{kRows, kCols, 0.5f}.Batches(3).GenerateSparsePageDMatrix("temp", true);
TestBasic(dmat.get(), &ctx);
}

View File

@@ -123,8 +123,8 @@ TEST(GPUPredictor, EllpackBasic) {
size_t rows = bins * 16;
auto p_m = RandomDataGenerator{rows, kCols, 0.0}
.Bins(bins)
.Device(DeviceOrd::CUDA(0))
.GenerateDeviceDMatrix(false);
.Device(ctx.Device())
.GenerateQuantileDMatrix(false);
ASSERT_FALSE(p_m->PageExists<SparsePage>());
TestPredictionFromGradientIndex<EllpackPage>(&ctx, rows, kCols, p_m);
TestPredictionFromGradientIndex<EllpackPage>(&ctx, bins, kCols, p_m);
@@ -137,7 +137,7 @@ TEST(GPUPredictor, EllpackTraining) {
auto p_ellpack = RandomDataGenerator{kRows, kCols, 0.0}
.Bins(kBins)
.Device(ctx.Device())
.GenerateDeviceDMatrix(false);
.GenerateQuantileDMatrix(false);
HostDeviceVector<float> storage(kRows * kCols);
auto columnar =
RandomDataGenerator{kRows, kCols, 0.0}.Device(ctx.Device()).GenerateArrayInterface(&storage);