[EM] Refactor ellpack construction. (#10810)

- Remove the calculation of n_symbols in the accessor.
- Pack initialization steps into the parameter list.
- Pass the context into various ctors.
- Specialization for dense data to prepare for further compression.
This commit is contained in:
Jiaming Yuan
2024-09-09 14:10:10 +08:00
committed by GitHub
parent c69c4adb58
commit 5f7f31d464
15 changed files with 187 additions and 158 deletions

View File

@@ -203,7 +203,8 @@ class BaseMGPUTest : public ::testing::Test {
* available.
*/
template <typename Fn>
auto DoTest(Fn&& fn, bool is_federated, [[maybe_unused]] bool emulate_if_single = false) const {
auto DoTest([[maybe_unused]] Fn&& fn, bool is_federated,
[[maybe_unused]] bool emulate_if_single = false) const {
auto n_gpus = common::AllVisibleGPUs();
if (is_federated) {
#if defined(XGBOOST_USE_FEDERATED)

View File

@@ -19,7 +19,7 @@ TEST(EllpackPage, EmptyDMatrix) {
constexpr int kNRows = 0, kNCols = 0, kMaxBin = 256;
constexpr float kSparsity = 0;
auto dmat = RandomDataGenerator(kNRows, kNCols, kSparsity).GenerateDMatrix();
Context ctx{MakeCUDACtx(0)};
auto ctx = MakeCUDACtx(0);
auto& page = *dmat->GetBatches<EllpackPage>(
&ctx, BatchParam{kMaxBin, tree::TrainParam::DftSparseThreshold()})
.begin();
@@ -94,7 +94,7 @@ TEST(EllpackPage, FromCategoricalBasic) {
Context ctx{MakeCUDACtx(0)};
auto p = BatchParam{max_bins, tree::TrainParam::DftSparseThreshold()};
auto ellpack = EllpackPage(&ctx, m.get(), p);
auto accessor = ellpack.Impl()->GetDeviceAccessor(ctx.Device());
auto accessor = ellpack.Impl()->GetDeviceAccessor(&ctx);
ASSERT_EQ(kCats, accessor.NumBins());
auto x_copy = x;
@@ -167,11 +167,11 @@ TEST(EllpackPage, Copy) {
EXPECT_EQ(impl->base_rowid, current_row);
for (size_t i = 0; i < impl->Size(); i++) {
dh::LaunchN(kCols, ReadRowFunction(impl->GetDeviceAccessor(ctx.Device()), current_row,
row_d.data().get()));
dh::LaunchN(kCols,
ReadRowFunction(impl->GetDeviceAccessor(&ctx), current_row, row_d.data().get()));
thrust::copy(row_d.begin(), row_d.end(), row.begin());
dh::LaunchN(kCols, ReadRowFunction(result.GetDeviceAccessor(ctx.Device()), current_row,
dh::LaunchN(kCols, ReadRowFunction(result.GetDeviceAccessor(&ctx), current_row,
row_result_d.data().get()));
thrust::copy(row_result_d.begin(), row_result_d.end(), row_result.begin());
@@ -223,12 +223,12 @@ TEST(EllpackPage, Compact) {
continue;
}
dh::LaunchN(kCols, ReadRowFunction(impl->GetDeviceAccessor(ctx.Device()), current_row,
row_d.data().get()));
dh::LaunchN(kCols,
ReadRowFunction(impl->GetDeviceAccessor(&ctx), current_row, row_d.data().get()));
dh::safe_cuda(cudaDeviceSynchronize());
thrust::copy(row_d.begin(), row_d.end(), row.begin());
dh::LaunchN(kCols, ReadRowFunction(result.GetDeviceAccessor(ctx.Device()), compacted_row,
dh::LaunchN(kCols, ReadRowFunction(result.GetDeviceAccessor(&ctx), compacted_row,
row_result_d.data().get()));
thrust::copy(row_result_d.begin(), row_result_d.end(), row_result.begin());

View File

@@ -27,7 +27,7 @@ void TestEquivalent(float sparsity) {
size_t num_elements = page_concatenated->Copy(&ctx, page, offset);
offset += num_elements;
}
auto from_iter = page_concatenated->GetDeviceAccessor(ctx.Device());
auto from_iter = page_concatenated->GetDeviceAccessor(&ctx);
ASSERT_EQ(m.Info().num_col_, CudaArrayIterForTest::Cols());
ASSERT_EQ(m.Info().num_row_, CudaArrayIterForTest::Rows());
@@ -37,7 +37,7 @@ void TestEquivalent(float sparsity) {
DMatrix::Create(&adapter, std::numeric_limits<float>::quiet_NaN(), 0)};
auto bp = BatchParam{256, tree::TrainParam::DftSparseThreshold()};
for (auto& ellpack : dm->GetBatches<EllpackPage>(&ctx, bp)) {
auto from_data = ellpack.Impl()->GetDeviceAccessor(ctx.Device());
auto from_data = ellpack.Impl()->GetDeviceAccessor(&ctx);
std::vector<float> cuts_from_iter(from_iter.gidx_fvalue_map.size());
std::vector<float> min_fvalues_iter(from_iter.min_fvalue.size());
@@ -71,7 +71,7 @@ void TestEquivalent(float sparsity) {
auto data_buf = ellpack.Impl()->GetHostAccessor(&ctx, &buffer_from_data);
ASSERT_NE(buffer_from_data.size(), 0);
ASSERT_NE(buffer_from_iter.size(), 0);
CHECK_EQ(from_data.NumSymbols(), from_iter.NumSymbols());
CHECK_EQ(ellpack.Impl()->NumSymbols(), page_concatenated->NumSymbols());
CHECK_EQ(from_data.n_rows * from_data.row_stride, from_data.n_rows * from_iter.row_stride);
for (size_t i = 0; i < from_data.n_rows * from_data.row_stride; ++i) {
CHECK_EQ(data_buf.gidx_iter[i], data_iter.gidx_iter[i]);
@@ -146,10 +146,10 @@ TEST(IterativeDeviceDMatrix, RowMajorMissing) {
auto impl = ellpack.Impl();
std::vector<common::CompressedByteT> h_gidx;
auto h_accessor = impl->GetHostAccessor(&ctx, &h_gidx);
EXPECT_EQ(h_accessor.gidx_iter[1], impl->GetDeviceAccessor(ctx.Device()).NullValue());
EXPECT_EQ(h_accessor.gidx_iter[5], impl->GetDeviceAccessor(ctx.Device()).NullValue());
EXPECT_EQ(h_accessor.gidx_iter[1], impl->GetDeviceAccessor(&ctx).NullValue());
EXPECT_EQ(h_accessor.gidx_iter[5], impl->GetDeviceAccessor(&ctx).NullValue());
// null values get placed after valid values in a row
EXPECT_EQ(h_accessor.gidx_iter[7], impl->GetDeviceAccessor(ctx.Device()).NullValue());
EXPECT_EQ(h_accessor.gidx_iter[7], impl->GetDeviceAccessor(&ctx).NullValue());
EXPECT_EQ(m.Info().num_col_, cols);
EXPECT_EQ(m.Info().num_row_, rows);
EXPECT_EQ(m.Info().num_nonzero_, rows* cols - 3);

View File

@@ -14,7 +14,7 @@
namespace xgboost {
TEST(SparsePageDMatrix, EllpackPage) {
Context ctx{MakeCUDACtx(0)};
auto ctx = MakeCUDACtx(0);
auto param = BatchParam{256, tree::TrainParam::DftSparseThreshold()};
dmlc::TemporaryDirectory tempdir;
const std::string tmp_file = tempdir.path + "/simple.libsvm";
@@ -301,11 +301,11 @@ TEST(SparsePageDMatrix, MultipleEllpackPageContent) {
EXPECT_EQ(impl_ext->base_rowid, current_row);
for (size_t i = 0; i < impl_ext->Size(); i++) {
dh::LaunchN(kCols, ReadRowFunction(impl->GetDeviceAccessor(ctx.Device()), current_row,
row_d.data().get()));
dh::LaunchN(kCols,
ReadRowFunction(impl->GetDeviceAccessor(&ctx), current_row, row_d.data().get()));
thrust::copy(row_d.begin(), row_d.end(), row.begin());
dh::LaunchN(kCols, ReadRowFunction(impl_ext->GetDeviceAccessor(ctx.Device()), current_row,
dh::LaunchN(kCols, ReadRowFunction(impl_ext->GetDeviceAccessor(&ctx), current_row,
row_ext_d.data().get()));
thrust::copy(row_ext_d.begin(), row_ext_d.end(), row_ext.begin());

View File

@@ -136,7 +136,7 @@ void TestBuildHist(bool use_shared_memory_histograms) {
feature_groups.DeviceAccessor(ctx.Device()), page->Cuts().TotalBins(),
!use_shared_memory_histograms);
builder.AllocateHistograms(&ctx, {0});
builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()),
builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(&ctx),
feature_groups.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(),
row_partitioner->GetRows(0), builder.GetNodeHistogram(0), *quantiser);
@@ -189,7 +189,7 @@ void TestDeterministicHistogram(bool is_dense, int shm_size, bool force_global)
DeviceHistogramBuilder builder;
builder.Reset(&ctx, HistMakerTrainParam::CudaDefaultNodes(),
feature_groups.DeviceAccessor(ctx.Device()), num_bins, force_global);
builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()),
builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(&ctx),
feature_groups.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(), ridx,
d_histogram, quantiser);
@@ -205,7 +205,7 @@ void TestDeterministicHistogram(bool is_dense, int shm_size, bool force_global)
DeviceHistogramBuilder builder;
builder.Reset(&ctx, HistMakerTrainParam::CudaDefaultNodes(),
feature_groups.DeviceAccessor(ctx.Device()), num_bins, force_global);
builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()),
builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(&ctx),
feature_groups.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(), ridx,
d_new_histogram, quantiser);
@@ -230,7 +230,7 @@ void TestDeterministicHistogram(bool is_dense, int shm_size, bool force_global)
DeviceHistogramBuilder builder;
builder.Reset(&ctx, HistMakerTrainParam::CudaDefaultNodes(),
single_group.DeviceAccessor(ctx.Device()), num_bins, force_global);
builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()),
builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(&ctx),
single_group.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(), ridx,
dh::ToSpan(baseline), quantiser);
@@ -298,7 +298,7 @@ void TestGPUHistogramCategorical(size_t num_categories) {
DeviceHistogramBuilder builder;
builder.Reset(&ctx, HistMakerTrainParam::CudaDefaultNodes(),
single_group.DeviceAccessor(ctx.Device()), num_categories, false);
builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()),
builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(&ctx),
single_group.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(), ridx,
dh::ToSpan(cat_hist), quantiser);
}
@@ -315,7 +315,7 @@ void TestGPUHistogramCategorical(size_t num_categories) {
DeviceHistogramBuilder builder;
builder.Reset(&ctx, HistMakerTrainParam::CudaDefaultNodes(),
single_group.DeviceAccessor(ctx.Device()), encode_hist.size(), false);
builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()),
builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(&ctx),
single_group.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(), ridx,
dh::ToSpan(encode_hist), quantiser);
}
@@ -449,7 +449,7 @@ class HistogramExternalMemoryTest : public ::testing::TestWithParam<std::tuple<f
auto impl = page.Impl();
if (k == 0) {
// Initialization
auto d_matrix = impl->GetDeviceAccessor(ctx.Device());
auto d_matrix = impl->GetDeviceAccessor(&ctx);
fg = std::make_unique<FeatureGroups>(impl->Cuts());
auto init = GradientPairInt64{0, 0};
multi_hist = decltype(multi_hist)(impl->Cuts().TotalBins(), init);
@@ -465,7 +465,7 @@ class HistogramExternalMemoryTest : public ::testing::TestWithParam<std::tuple<f
DeviceHistogramBuilder builder;
builder.Reset(&ctx, HistMakerTrainParam::CudaDefaultNodes(),
fg->DeviceAccessor(ctx.Device()), d_histogram.size(), force_global);
builder.BuildHistogram(ctx.CUDACtx(), impl->GetDeviceAccessor(ctx.Device()),
builder.BuildHistogram(ctx.CUDACtx(), impl->GetDeviceAccessor(&ctx),
fg->DeviceAccessor(ctx.Device()), gpair.ConstDeviceSpan(), ridx,
d_histogram, quantiser);
++k;
@@ -491,7 +491,7 @@ class HistogramExternalMemoryTest : public ::testing::TestWithParam<std::tuple<f
DeviceHistogramBuilder builder;
builder.Reset(&ctx, HistMakerTrainParam::CudaDefaultNodes(), fg->DeviceAccessor(ctx.Device()),
d_histogram.size(), force_global);
builder.BuildHistogram(ctx.CUDACtx(), page.GetDeviceAccessor(ctx.Device()),
builder.BuildHistogram(ctx.CUDACtx(), page.GetDeviceAccessor(&ctx),
fg->DeviceAccessor(ctx.Device()), gpair.ConstDeviceSpan(), ridx,
d_histogram, quantiser);
}