[EM] Enable access to the number of batches. (#10691)
- Expose `NumBatches` in `DMatrix`.
- Small cleanup for removing legacy CUDA stream and ~force CUDA context initialization~.
- Purge old external memory data generation code.
This commit is contained in:
@@ -7,22 +7,18 @@
|
||||
#include "../../../../src/tree/gpu_hist/gradient_based_sampler.cuh"
|
||||
#include "../../../../src/tree/param.h"
|
||||
#include "../../../../src/tree/param.h" // TrainParam
|
||||
#include "../../filesystem.h" // dmlc::TemporaryDirectory
|
||||
#include "../../helpers.h"
|
||||
|
||||
namespace xgboost::tree {
|
||||
void VerifySampling(size_t page_size,
|
||||
float subsample,
|
||||
int sampling_method,
|
||||
bool fixed_size_sampling = true,
|
||||
bool check_sum = true) {
|
||||
void VerifySampling(size_t page_size, float subsample, int sampling_method,
|
||||
bool fixed_size_sampling = true, bool check_sum = true) {
|
||||
constexpr size_t kRows = 4096;
|
||||
constexpr size_t kCols = 1;
|
||||
size_t sample_rows = kRows * subsample;
|
||||
bst_idx_t sample_rows = kRows * subsample;
|
||||
bst_idx_t n_batches = fixed_size_sampling ? 1 : 4;
|
||||
|
||||
dmlc::TemporaryDirectory tmpdir;
|
||||
std::unique_ptr<DMatrix> dmat(CreateSparsePageDMatrix(
|
||||
kRows, kCols, kRows / (page_size == 0 ? kRows : page_size), tmpdir.path + "/cache"));
|
||||
auto dmat = RandomDataGenerator{kRows, kCols, 0.0f}.Batches(n_batches).GenerateSparsePageDMatrix(
|
||||
"temp", true);
|
||||
auto gpair = GenerateRandomGradients(kRows);
|
||||
GradientPair sum_gpair{};
|
||||
for (const auto& gp : gpair.ConstHostVector()) {
|
||||
@@ -78,14 +74,12 @@ TEST(GradientBasedSampler, NoSamplingExternalMemory) {
|
||||
constexpr size_t kRows = 2048;
|
||||
constexpr size_t kCols = 1;
|
||||
constexpr float kSubsample = 1.0f;
|
||||
constexpr size_t kPageSize = 1024;
|
||||
|
||||
// Create a DMatrix with multiple batches.
|
||||
dmlc::TemporaryDirectory tmpdir;
|
||||
std::unique_ptr<DMatrix> dmat(
|
||||
CreateSparsePageDMatrix(kRows, kCols, kRows / kPageSize, tmpdir.path + "/cache"));
|
||||
auto dmat =
|
||||
RandomDataGenerator{kRows, kCols, 0.0f}.Batches(4).GenerateSparsePageDMatrix("temp", true);
|
||||
auto gpair = GenerateRandomGradients(kRows);
|
||||
Context ctx{MakeCUDACtx(0)};
|
||||
auto ctx = MakeCUDACtx(0);
|
||||
gpair.SetDevice(ctx.Device());
|
||||
|
||||
auto param = BatchParam{256, tree::TrainParam::DftSparseThreshold()};
|
||||
|
||||
Reference in New Issue
Block a user