Use Booster context in DMatrix. (#8896)
- Pass context from booster to DMatrix. - Use context instead of integer for `n_threads`. - Check the consistency configuration for `max_bin`. - Test for all combinations of initialization options.
This commit is contained in:
@@ -208,17 +208,16 @@ TEST(GpuHist, TestHistogramIndex) {
|
||||
TestHistogramIndexImpl();
|
||||
}
|
||||
|
||||
void UpdateTree(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
|
||||
size_t gpu_page_size, RegTree* tree,
|
||||
HostDeviceVector<bst_float>* preds, float subsample = 1.0f,
|
||||
const std::string& sampling_method = "uniform",
|
||||
void UpdateTree(Context const* ctx, HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
|
||||
size_t gpu_page_size, RegTree* tree, HostDeviceVector<bst_float>* preds,
|
||||
float subsample = 1.0f, const std::string& sampling_method = "uniform",
|
||||
int max_bin = 2) {
|
||||
|
||||
if (gpu_page_size > 0) {
|
||||
// Loop over the batches and count the records
|
||||
int64_t batch_count = 0;
|
||||
int64_t row_count = 0;
|
||||
for (const auto& batch : dmat->GetBatches<EllpackPage>({0, max_bin})) {
|
||||
for (const auto& batch : dmat->GetBatches<EllpackPage>(
|
||||
ctx, BatchParam{max_bin, TrainParam::DftSparseThreshold()})) {
|
||||
EXPECT_LT(batch.Size(), dmat->Info().num_row_);
|
||||
batch_count++;
|
||||
row_count += batch.Size();
|
||||
@@ -239,14 +238,13 @@ void UpdateTree(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
|
||||
TrainParam param;
|
||||
param.UpdateAllowUnknown(args);
|
||||
|
||||
Context ctx(CreateEmptyGenericParam(0));
|
||||
ObjInfo task{ObjInfo::kRegression};
|
||||
tree::GPUHistMaker hist_maker{&ctx, &task};
|
||||
tree::GPUHistMaker hist_maker{ctx, &task};
|
||||
|
||||
std::vector<HostDeviceVector<bst_node_t>> position(1);
|
||||
hist_maker.Update(¶m, gpair, dmat, common::Span<HostDeviceVector<bst_node_t>>{position},
|
||||
{tree});
|
||||
auto cache = linalg::MakeTensorView(&ctx, preds->DeviceSpan(), preds->Size(), 1);
|
||||
auto cache = linalg::MakeTensorView(ctx, preds->DeviceSpan(), preds->Size(), 1);
|
||||
hist_maker.UpdatePredictionCache(dmat, cache);
|
||||
}
|
||||
|
||||
@@ -264,12 +262,13 @@ TEST(GpuHist, UniformSampling) {
|
||||
// Build a tree using the in-memory DMatrix.
|
||||
RegTree tree;
|
||||
HostDeviceVector<bst_float> preds(kRows, 0.0, 0);
|
||||
UpdateTree(&gpair, dmat.get(), 0, &tree, &preds, 1.0, "uniform", kRows);
|
||||
Context ctx(CreateEmptyGenericParam(0));
|
||||
UpdateTree(&ctx, &gpair, dmat.get(), 0, &tree, &preds, 1.0, "uniform", kRows);
|
||||
// Build another tree using sampling.
|
||||
RegTree tree_sampling;
|
||||
HostDeviceVector<bst_float> preds_sampling(kRows, 0.0, 0);
|
||||
UpdateTree(&gpair, dmat.get(), 0, &tree_sampling, &preds_sampling, kSubsample,
|
||||
"uniform", kRows);
|
||||
UpdateTree(&ctx, &gpair, dmat.get(), 0, &tree_sampling, &preds_sampling, kSubsample, "uniform",
|
||||
kRows);
|
||||
|
||||
// Make sure the predictions are the same.
|
||||
auto preds_h = preds.ConstHostVector();
|
||||
@@ -293,12 +292,13 @@ TEST(GpuHist, GradientBasedSampling) {
|
||||
// Build a tree using the in-memory DMatrix.
|
||||
RegTree tree;
|
||||
HostDeviceVector<bst_float> preds(kRows, 0.0, 0);
|
||||
UpdateTree(&gpair, dmat.get(), 0, &tree, &preds, 1.0, "uniform", kRows);
|
||||
Context ctx(CreateEmptyGenericParam(0));
|
||||
UpdateTree(&ctx, &gpair, dmat.get(), 0, &tree, &preds, 1.0, "uniform", kRows);
|
||||
|
||||
// Build another tree using sampling.
|
||||
RegTree tree_sampling;
|
||||
HostDeviceVector<bst_float> preds_sampling(kRows, 0.0, 0);
|
||||
UpdateTree(&gpair, dmat.get(), 0, &tree_sampling, &preds_sampling, kSubsample,
|
||||
UpdateTree(&ctx, &gpair, dmat.get(), 0, &tree_sampling, &preds_sampling, kSubsample,
|
||||
"gradient_based", kRows);
|
||||
|
||||
// Make sure the predictions are the same.
|
||||
@@ -327,12 +327,13 @@ TEST(GpuHist, ExternalMemory) {
|
||||
|
||||
// Build a tree using the in-memory DMatrix.
|
||||
RegTree tree;
|
||||
Context ctx(CreateEmptyGenericParam(0));
|
||||
HostDeviceVector<bst_float> preds(kRows, 0.0, 0);
|
||||
UpdateTree(&gpair, dmat.get(), 0, &tree, &preds, 1.0, "uniform", kRows);
|
||||
UpdateTree(&ctx, &gpair, dmat.get(), 0, &tree, &preds, 1.0, "uniform", kRows);
|
||||
// Build another tree using multiple ELLPACK pages.
|
||||
RegTree tree_ext;
|
||||
HostDeviceVector<bst_float> preds_ext(kRows, 0.0, 0);
|
||||
UpdateTree(&gpair, dmat_ext.get(), kPageSize, &tree_ext, &preds_ext, 1.0, "uniform", kRows);
|
||||
UpdateTree(&ctx, &gpair, dmat_ext.get(), kPageSize, &tree_ext, &preds_ext, 1.0, "uniform", kRows);
|
||||
|
||||
// Make sure the predictions are the same.
|
||||
auto preds_h = preds.ConstHostVector();
|
||||
@@ -364,17 +365,17 @@ TEST(GpuHist, ExternalMemoryWithSampling) {
|
||||
// Build a tree using the in-memory DMatrix.
|
||||
auto rng = common::GlobalRandom();
|
||||
|
||||
Context ctx(CreateEmptyGenericParam(0));
|
||||
RegTree tree;
|
||||
HostDeviceVector<bst_float> preds(kRows, 0.0, 0);
|
||||
UpdateTree(&gpair, dmat.get(), 0, &tree, &preds, kSubsample, kSamplingMethod,
|
||||
kRows);
|
||||
UpdateTree(&ctx, &gpair, dmat.get(), 0, &tree, &preds, kSubsample, kSamplingMethod, kRows);
|
||||
|
||||
// Build another tree using multiple ELLPACK pages.
|
||||
common::GlobalRandom() = rng;
|
||||
RegTree tree_ext;
|
||||
HostDeviceVector<bst_float> preds_ext(kRows, 0.0, 0);
|
||||
UpdateTree(&gpair, dmat_ext.get(), kPageSize, &tree_ext, &preds_ext,
|
||||
kSubsample, kSamplingMethod, kRows);
|
||||
UpdateTree(&ctx, &gpair, dmat_ext.get(), kPageSize, &tree_ext, &preds_ext, kSubsample,
|
||||
kSamplingMethod, kRows);
|
||||
|
||||
// Make sure the predictions are the same.
|
||||
auto preds_h = preds.ConstHostVector();
|
||||
|
||||
Reference in New Issue
Block a user