Use Booster context in DMatrix. (#8896)

- Pass context from booster to DMatrix.
- Use context instead of integer for `n_threads`.
- Check the consistency configuration for `max_bin`.
- Test for all combinations of initialization options.
This commit is contained in:
Jiaming Yuan
2023-04-28 21:47:14 +08:00
committed by GitHub
parent 1f9a57d17b
commit 08ce495b5d
67 changed files with 1283 additions and 935 deletions

View File

@@ -208,17 +208,16 @@ TEST(GpuHist, TestHistogramIndex) {
TestHistogramIndexImpl();
}
void UpdateTree(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
size_t gpu_page_size, RegTree* tree,
HostDeviceVector<bst_float>* preds, float subsample = 1.0f,
const std::string& sampling_method = "uniform",
void UpdateTree(Context const* ctx, HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
size_t gpu_page_size, RegTree* tree, HostDeviceVector<bst_float>* preds,
float subsample = 1.0f, const std::string& sampling_method = "uniform",
int max_bin = 2) {
if (gpu_page_size > 0) {
// Loop over the batches and count the records
int64_t batch_count = 0;
int64_t row_count = 0;
for (const auto& batch : dmat->GetBatches<EllpackPage>({0, max_bin})) {
for (const auto& batch : dmat->GetBatches<EllpackPage>(
ctx, BatchParam{max_bin, TrainParam::DftSparseThreshold()})) {
EXPECT_LT(batch.Size(), dmat->Info().num_row_);
batch_count++;
row_count += batch.Size();
@@ -239,14 +238,13 @@ void UpdateTree(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
TrainParam param;
param.UpdateAllowUnknown(args);
Context ctx(CreateEmptyGenericParam(0));
ObjInfo task{ObjInfo::kRegression};
tree::GPUHistMaker hist_maker{&ctx, &task};
tree::GPUHistMaker hist_maker{ctx, &task};
std::vector<HostDeviceVector<bst_node_t>> position(1);
hist_maker.Update(&param, gpair, dmat, common::Span<HostDeviceVector<bst_node_t>>{position},
{tree});
auto cache = linalg::MakeTensorView(&ctx, preds->DeviceSpan(), preds->Size(), 1);
auto cache = linalg::MakeTensorView(ctx, preds->DeviceSpan(), preds->Size(), 1);
hist_maker.UpdatePredictionCache(dmat, cache);
}
@@ -264,12 +262,13 @@ TEST(GpuHist, UniformSampling) {
// Build a tree using the in-memory DMatrix.
RegTree tree;
HostDeviceVector<bst_float> preds(kRows, 0.0, 0);
UpdateTree(&gpair, dmat.get(), 0, &tree, &preds, 1.0, "uniform", kRows);
Context ctx(CreateEmptyGenericParam(0));
UpdateTree(&ctx, &gpair, dmat.get(), 0, &tree, &preds, 1.0, "uniform", kRows);
// Build another tree using sampling.
RegTree tree_sampling;
HostDeviceVector<bst_float> preds_sampling(kRows, 0.0, 0);
UpdateTree(&gpair, dmat.get(), 0, &tree_sampling, &preds_sampling, kSubsample,
"uniform", kRows);
UpdateTree(&ctx, &gpair, dmat.get(), 0, &tree_sampling, &preds_sampling, kSubsample, "uniform",
kRows);
// Make sure the predictions are the same.
auto preds_h = preds.ConstHostVector();
@@ -293,12 +292,13 @@ TEST(GpuHist, GradientBasedSampling) {
// Build a tree using the in-memory DMatrix.
RegTree tree;
HostDeviceVector<bst_float> preds(kRows, 0.0, 0);
UpdateTree(&gpair, dmat.get(), 0, &tree, &preds, 1.0, "uniform", kRows);
Context ctx(CreateEmptyGenericParam(0));
UpdateTree(&ctx, &gpair, dmat.get(), 0, &tree, &preds, 1.0, "uniform", kRows);
// Build another tree using sampling.
RegTree tree_sampling;
HostDeviceVector<bst_float> preds_sampling(kRows, 0.0, 0);
UpdateTree(&gpair, dmat.get(), 0, &tree_sampling, &preds_sampling, kSubsample,
UpdateTree(&ctx, &gpair, dmat.get(), 0, &tree_sampling, &preds_sampling, kSubsample,
"gradient_based", kRows);
// Make sure the predictions are the same.
@@ -327,12 +327,13 @@ TEST(GpuHist, ExternalMemory) {
// Build a tree using the in-memory DMatrix.
RegTree tree;
Context ctx(CreateEmptyGenericParam(0));
HostDeviceVector<bst_float> preds(kRows, 0.0, 0);
UpdateTree(&gpair, dmat.get(), 0, &tree, &preds, 1.0, "uniform", kRows);
UpdateTree(&ctx, &gpair, dmat.get(), 0, &tree, &preds, 1.0, "uniform", kRows);
// Build another tree using multiple ELLPACK pages.
RegTree tree_ext;
HostDeviceVector<bst_float> preds_ext(kRows, 0.0, 0);
UpdateTree(&gpair, dmat_ext.get(), kPageSize, &tree_ext, &preds_ext, 1.0, "uniform", kRows);
UpdateTree(&ctx, &gpair, dmat_ext.get(), kPageSize, &tree_ext, &preds_ext, 1.0, "uniform", kRows);
// Make sure the predictions are the same.
auto preds_h = preds.ConstHostVector();
@@ -364,17 +365,17 @@ TEST(GpuHist, ExternalMemoryWithSampling) {
// Build a tree using the in-memory DMatrix.
auto rng = common::GlobalRandom();
Context ctx(CreateEmptyGenericParam(0));
RegTree tree;
HostDeviceVector<bst_float> preds(kRows, 0.0, 0);
UpdateTree(&gpair, dmat.get(), 0, &tree, &preds, kSubsample, kSamplingMethod,
kRows);
UpdateTree(&ctx, &gpair, dmat.get(), 0, &tree, &preds, kSubsample, kSamplingMethod, kRows);
// Build another tree using multiple ELLPACK pages.
common::GlobalRandom() = rng;
RegTree tree_ext;
HostDeviceVector<bst_float> preds_ext(kRows, 0.0, 0);
UpdateTree(&gpair, dmat_ext.get(), kPageSize, &tree_ext, &preds_ext,
kSubsample, kSamplingMethod, kRows);
UpdateTree(&ctx, &gpair, dmat_ext.get(), kPageSize, &tree_ext, &preds_ext, kSubsample,
kSamplingMethod, kRows);
// Make sure the predictions are the same.
auto preds_h = preds.ConstHostVector();