[EM] Refactor ellpack construction. (#10810)

- Remove the calculation of n_symbols in the accessor.
- Pack initialization steps into the parameter list.
- Pass the context into various ctors.
- Specialization for dense data to prepare for further compression.
This commit is contained in:
Jiaming Yuan
2024-09-09 14:10:10 +08:00
committed by GitHub
parent c69c4adb58
commit 5f7f31d464
15 changed files with 187 additions and 158 deletions

View File

@@ -136,7 +136,7 @@ void TestBuildHist(bool use_shared_memory_histograms) {
feature_groups.DeviceAccessor(ctx.Device()), page->Cuts().TotalBins(),
!use_shared_memory_histograms);
builder.AllocateHistograms(&ctx, {0});
builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()),
builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(&ctx),
feature_groups.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(),
row_partitioner->GetRows(0), builder.GetNodeHistogram(0), *quantiser);
@@ -189,7 +189,7 @@ void TestDeterministicHistogram(bool is_dense, int shm_size, bool force_global)
DeviceHistogramBuilder builder;
builder.Reset(&ctx, HistMakerTrainParam::CudaDefaultNodes(),
feature_groups.DeviceAccessor(ctx.Device()), num_bins, force_global);
builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()),
builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(&ctx),
feature_groups.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(), ridx,
d_histogram, quantiser);
@@ -205,7 +205,7 @@ void TestDeterministicHistogram(bool is_dense, int shm_size, bool force_global)
DeviceHistogramBuilder builder;
builder.Reset(&ctx, HistMakerTrainParam::CudaDefaultNodes(),
feature_groups.DeviceAccessor(ctx.Device()), num_bins, force_global);
builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()),
builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(&ctx),
feature_groups.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(), ridx,
d_new_histogram, quantiser);
@@ -230,7 +230,7 @@ void TestDeterministicHistogram(bool is_dense, int shm_size, bool force_global)
DeviceHistogramBuilder builder;
builder.Reset(&ctx, HistMakerTrainParam::CudaDefaultNodes(),
single_group.DeviceAccessor(ctx.Device()), num_bins, force_global);
builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()),
builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(&ctx),
single_group.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(), ridx,
dh::ToSpan(baseline), quantiser);
@@ -298,7 +298,7 @@ void TestGPUHistogramCategorical(size_t num_categories) {
DeviceHistogramBuilder builder;
builder.Reset(&ctx, HistMakerTrainParam::CudaDefaultNodes(),
single_group.DeviceAccessor(ctx.Device()), num_categories, false);
builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()),
builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(&ctx),
single_group.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(), ridx,
dh::ToSpan(cat_hist), quantiser);
}
@@ -315,7 +315,7 @@ void TestGPUHistogramCategorical(size_t num_categories) {
DeviceHistogramBuilder builder;
builder.Reset(&ctx, HistMakerTrainParam::CudaDefaultNodes(),
single_group.DeviceAccessor(ctx.Device()), encode_hist.size(), false);
builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()),
builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(&ctx),
single_group.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(), ridx,
dh::ToSpan(encode_hist), quantiser);
}
@@ -449,7 +449,7 @@ class HistogramExternalMemoryTest : public ::testing::TestWithParam<std::tuple<f
auto impl = page.Impl();
if (k == 0) {
// Initialization
auto d_matrix = impl->GetDeviceAccessor(ctx.Device());
auto d_matrix = impl->GetDeviceAccessor(&ctx);
fg = std::make_unique<FeatureGroups>(impl->Cuts());
auto init = GradientPairInt64{0, 0};
multi_hist = decltype(multi_hist)(impl->Cuts().TotalBins(), init);
@@ -465,7 +465,7 @@ class HistogramExternalMemoryTest : public ::testing::TestWithParam<std::tuple<f
DeviceHistogramBuilder builder;
builder.Reset(&ctx, HistMakerTrainParam::CudaDefaultNodes(),
fg->DeviceAccessor(ctx.Device()), d_histogram.size(), force_global);
builder.BuildHistogram(ctx.CUDACtx(), impl->GetDeviceAccessor(ctx.Device()),
builder.BuildHistogram(ctx.CUDACtx(), impl->GetDeviceAccessor(&ctx),
fg->DeviceAccessor(ctx.Device()), gpair.ConstDeviceSpan(), ridx,
d_histogram, quantiser);
++k;
@@ -491,7 +491,7 @@ class HistogramExternalMemoryTest : public ::testing::TestWithParam<std::tuple<f
DeviceHistogramBuilder builder;
builder.Reset(&ctx, HistMakerTrainParam::CudaDefaultNodes(), fg->DeviceAccessor(ctx.Device()),
d_histogram.size(), force_global);
builder.BuildHistogram(ctx.CUDACtx(), page.GetDeviceAccessor(ctx.Device()),
builder.BuildHistogram(ctx.CUDACtx(), page.GetDeviceAccessor(&ctx),
fg->DeviceAccessor(ctx.Device()), gpair.ConstDeviceSpan(), ridx,
d_histogram, quantiser);
}