[EM] Refactor ellpack construction. (#10810)
- Remove the calculation of n_symbols in the accessor. - Pack initialization steps into the parameter list. - Pass the context into various ctors. - Specialization for dense data to prepare for further compression.
This commit is contained in:
@@ -136,7 +136,7 @@ void TestBuildHist(bool use_shared_memory_histograms) {
|
||||
feature_groups.DeviceAccessor(ctx.Device()), page->Cuts().TotalBins(),
|
||||
!use_shared_memory_histograms);
|
||||
builder.AllocateHistograms(&ctx, {0});
|
||||
builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()),
|
||||
builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(&ctx),
|
||||
feature_groups.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(),
|
||||
row_partitioner->GetRows(0), builder.GetNodeHistogram(0), *quantiser);
|
||||
|
||||
@@ -189,7 +189,7 @@ void TestDeterministicHistogram(bool is_dense, int shm_size, bool force_global)
|
||||
DeviceHistogramBuilder builder;
|
||||
builder.Reset(&ctx, HistMakerTrainParam::CudaDefaultNodes(),
|
||||
feature_groups.DeviceAccessor(ctx.Device()), num_bins, force_global);
|
||||
builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()),
|
||||
builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(&ctx),
|
||||
feature_groups.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(), ridx,
|
||||
d_histogram, quantiser);
|
||||
|
||||
@@ -205,7 +205,7 @@ void TestDeterministicHistogram(bool is_dense, int shm_size, bool force_global)
|
||||
DeviceHistogramBuilder builder;
|
||||
builder.Reset(&ctx, HistMakerTrainParam::CudaDefaultNodes(),
|
||||
feature_groups.DeviceAccessor(ctx.Device()), num_bins, force_global);
|
||||
builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()),
|
||||
builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(&ctx),
|
||||
feature_groups.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(), ridx,
|
||||
d_new_histogram, quantiser);
|
||||
|
||||
@@ -230,7 +230,7 @@ void TestDeterministicHistogram(bool is_dense, int shm_size, bool force_global)
|
||||
DeviceHistogramBuilder builder;
|
||||
builder.Reset(&ctx, HistMakerTrainParam::CudaDefaultNodes(),
|
||||
single_group.DeviceAccessor(ctx.Device()), num_bins, force_global);
|
||||
builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()),
|
||||
builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(&ctx),
|
||||
single_group.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(), ridx,
|
||||
dh::ToSpan(baseline), quantiser);
|
||||
|
||||
@@ -298,7 +298,7 @@ void TestGPUHistogramCategorical(size_t num_categories) {
|
||||
DeviceHistogramBuilder builder;
|
||||
builder.Reset(&ctx, HistMakerTrainParam::CudaDefaultNodes(),
|
||||
single_group.DeviceAccessor(ctx.Device()), num_categories, false);
|
||||
builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()),
|
||||
builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(&ctx),
|
||||
single_group.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(), ridx,
|
||||
dh::ToSpan(cat_hist), quantiser);
|
||||
}
|
||||
@@ -315,7 +315,7 @@ void TestGPUHistogramCategorical(size_t num_categories) {
|
||||
DeviceHistogramBuilder builder;
|
||||
builder.Reset(&ctx, HistMakerTrainParam::CudaDefaultNodes(),
|
||||
single_group.DeviceAccessor(ctx.Device()), encode_hist.size(), false);
|
||||
builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()),
|
||||
builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(&ctx),
|
||||
single_group.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(), ridx,
|
||||
dh::ToSpan(encode_hist), quantiser);
|
||||
}
|
||||
@@ -449,7 +449,7 @@ class HistogramExternalMemoryTest : public ::testing::TestWithParam<std::tuple<f
|
||||
auto impl = page.Impl();
|
||||
if (k == 0) {
|
||||
// Initialization
|
||||
auto d_matrix = impl->GetDeviceAccessor(ctx.Device());
|
||||
auto d_matrix = impl->GetDeviceAccessor(&ctx);
|
||||
fg = std::make_unique<FeatureGroups>(impl->Cuts());
|
||||
auto init = GradientPairInt64{0, 0};
|
||||
multi_hist = decltype(multi_hist)(impl->Cuts().TotalBins(), init);
|
||||
@@ -465,7 +465,7 @@ class HistogramExternalMemoryTest : public ::testing::TestWithParam<std::tuple<f
|
||||
DeviceHistogramBuilder builder;
|
||||
builder.Reset(&ctx, HistMakerTrainParam::CudaDefaultNodes(),
|
||||
fg->DeviceAccessor(ctx.Device()), d_histogram.size(), force_global);
|
||||
builder.BuildHistogram(ctx.CUDACtx(), impl->GetDeviceAccessor(ctx.Device()),
|
||||
builder.BuildHistogram(ctx.CUDACtx(), impl->GetDeviceAccessor(&ctx),
|
||||
fg->DeviceAccessor(ctx.Device()), gpair.ConstDeviceSpan(), ridx,
|
||||
d_histogram, quantiser);
|
||||
++k;
|
||||
@@ -491,7 +491,7 @@ class HistogramExternalMemoryTest : public ::testing::TestWithParam<std::tuple<f
|
||||
DeviceHistogramBuilder builder;
|
||||
builder.Reset(&ctx, HistMakerTrainParam::CudaDefaultNodes(), fg->DeviceAccessor(ctx.Device()),
|
||||
d_histogram.size(), force_global);
|
||||
builder.BuildHistogram(ctx.CUDACtx(), page.GetDeviceAccessor(ctx.Device()),
|
||||
builder.BuildHistogram(ctx.CUDACtx(), page.GetDeviceAccessor(&ctx),
|
||||
fg->DeviceAccessor(ctx.Device()), gpair.ConstDeviceSpan(), ridx,
|
||||
d_histogram, quantiser);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user