From 3632242e0b680592a0bbae7b086c42e52741cfaa Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Mon, 10 Jul 2023 21:15:56 -0700 Subject: [PATCH] Support column split with GPU quantile (#9370) --- src/common/hist_util.cu | 2 +- src/common/quantile.cu | 8 ++--- src/common/quantile.cuh | 4 +-- src/data/iterative_dmatrix.cu | 2 +- tests/cpp/common/test_hist_util.cu | 12 +++---- tests/cpp/common/test_quantile.cu | 56 ++++++++++++++++++++++++++++-- 6 files changed, 68 insertions(+), 16 deletions(-) diff --git a/src/common/hist_util.cu b/src/common/hist_util.cu index 76fff8a98..1c9525a62 100644 --- a/src/common/hist_util.cu +++ b/src/common/hist_util.cu @@ -352,7 +352,7 @@ HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins, } } } - sketch_container.MakeCuts(&cuts); + sketch_container.MakeCuts(&cuts, dmat->Info().IsColumnSplit()); return cuts; } } // namespace common diff --git a/src/common/quantile.cu b/src/common/quantile.cu index 5c81ec2ea..25c4543c6 100644 --- a/src/common/quantile.cu +++ b/src/common/quantile.cu @@ -501,10 +501,10 @@ void SketchContainer::FixError() { }); } -void SketchContainer::AllReduce() { +void SketchContainer::AllReduce(bool is_column_split) { dh::safe_cuda(cudaSetDevice(device_)); auto world = collective::GetWorldSize(); - if (world == 1) { + if (world == 1 || is_column_split) { return; } @@ -582,13 +582,13 @@ struct InvalidCatOp { }; } // anonymous namespace -void SketchContainer::MakeCuts(HistogramCuts* p_cuts) { +void SketchContainer::MakeCuts(HistogramCuts* p_cuts, bool is_column_split) { timer_.Start(__func__); dh::safe_cuda(cudaSetDevice(device_)); p_cuts->min_vals_.Resize(num_columns_); // Sync between workers. - this->AllReduce(); + this->AllReduce(is_column_split); // Prune to final number of bins. 
this->Prune(num_bins_ + 1); diff --git a/src/common/quantile.cuh b/src/common/quantile.cuh index 7ebd4ff51..fedbdbd82 100644 --- a/src/common/quantile.cuh +++ b/src/common/quantile.cuh @@ -154,9 +154,9 @@ class SketchContainer { Span that); /* \brief Merge quantiles from other GPU workers. */ - void AllReduce(); + void AllReduce(bool is_column_split); /* \brief Create the final histogram cut values. */ - void MakeCuts(HistogramCuts* cuts); + void MakeCuts(HistogramCuts* cuts, bool is_column_split); Span Data() const { return {this->Current().data().get(), this->Current().size()}; diff --git a/src/data/iterative_dmatrix.cu b/src/data/iterative_dmatrix.cu index a760ec9ab..1e74cb23c 100644 --- a/src/data/iterative_dmatrix.cu +++ b/src/data/iterative_dmatrix.cu @@ -106,7 +106,7 @@ void IterativeDMatrix::InitFromCUDA(Context const* ctx, BatchParam const& p, sketch_containers.clear(); sketch_containers.shrink_to_fit(); - final_sketch.MakeCuts(&cuts); + final_sketch.MakeCuts(&cuts, this->info_.IsColumnSplit()); } else { GetCutsFromRef(ctx, ref, Info().num_col_, p, &cuts); } diff --git a/tests/cpp/common/test_hist_util.cu b/tests/cpp/common/test_hist_util.cu index 20fd1043d..127cd95d4 100644 --- a/tests/cpp/common/test_hist_util.cu +++ b/tests/cpp/common/test_hist_util.cu @@ -351,7 +351,7 @@ auto MakeUnweightedCutsForTest(Adapter adapter, int32_t num_bins, float missing, SketchContainer sketch_container(ft, num_bins, adapter.NumColumns(), adapter.NumRows(), 0); MetaInfo info; AdapterDeviceSketch(adapter.Value(), num_bins, info, missing, &sketch_container, batch_size); - sketch_container.MakeCuts(&batched_cuts); + sketch_container.MakeCuts(&batched_cuts, info.IsColumnSplit()); return batched_cuts; } @@ -419,7 +419,7 @@ TEST(HistUtil, AdapterSketchSlidingWindowMemory) { AdapterDeviceSketch(adapter.Value(), num_bins, info, std::numeric_limits<float>::quiet_NaN(), &sketch_container); HistogramCuts cuts; - sketch_container.MakeCuts(&cuts); + sketch_container.MakeCuts(&cuts, 
info.IsColumnSplit()); size_t bytes_required = detail::RequiredMemory( num_rows, num_columns, num_rows * num_columns, num_bins, false); EXPECT_LE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required * 1.05); @@ -449,7 +449,7 @@ TEST(HistUtil, AdapterSketchSlidingWindowWeightedMemory) { &sketch_container); HistogramCuts cuts; - sketch_container.MakeCuts(&cuts); + sketch_container.MakeCuts(&cuts, info.IsColumnSplit()); ConsoleLogger::Configure({{"verbosity", "0"}}); size_t bytes_required = detail::RequiredMemory( num_rows, num_columns, num_rows * num_columns, num_bins, true); @@ -482,7 +482,7 @@ void TestCategoricalSketchAdapter(size_t n, size_t num_categories, AdapterDeviceSketch(adapter.Value(), num_bins, info, std::numeric_limits<float>::quiet_NaN(), &container); HistogramCuts cuts; - container.MakeCuts(&cuts); + container.MakeCuts(&cuts, info.IsColumnSplit()); thrust::sort(x.begin(), x.end()); auto n_uniques = thrust::unique(x.begin(), x.end()) - x.begin(); @@ -710,7 +710,7 @@ void TestAdapterSketchFromWeights(bool with_group) { &sketch_container); common::HistogramCuts cuts; - sketch_container.MakeCuts(&cuts); + sketch_container.MakeCuts(&cuts, info.IsColumnSplit()); auto dmat = GetDMatrixFromData(storage.HostVector(), kRows, kCols); if (with_group) { @@ -751,7 +751,7 @@ void TestAdapterSketchFromWeights(bool with_group) { SketchContainer sketch_container(ft, kBins, kCols, kRows, 0); AdapterDeviceSketch(adapter.Value(), kBins, info, std::numeric_limits<float>::quiet_NaN(), &sketch_container); - sketch_container.MakeCuts(&weighted); + sketch_container.MakeCuts(&weighted, info.IsColumnSplit()); ValidateCuts(weighted, dmat.get(), kBins); } } diff --git a/tests/cpp/common/test_quantile.cu b/tests/cpp/common/test_quantile.cu index 935d88ab6..d2dc802a9 100644 --- a/tests/cpp/common/test_quantile.cu +++ b/tests/cpp/common/test_quantile.cu @@ -388,7 +388,7 @@ void TestAllReduceBasic(int32_t n_gpus) { AdapterDeviceSketch(adapter.Value(), n_bins, info, 
std::numeric_limits<float>::quiet_NaN(), &sketch_distributed); - sketch_distributed.AllReduce(); + sketch_distributed.AllReduce(false); sketch_distributed.Unique(); ASSERT_EQ(sketch_distributed.ColumnsPtr().size(), @@ -425,6 +425,58 @@ TEST(GPUQuantile, MGPUAllReduceBasic) { RunWithInMemoryCommunicator(n_gpus, TestAllReduceBasic, n_gpus); } +namespace { +void TestColumnSplitBasic() { + auto const world = collective::GetWorldSize(); + auto const rank = collective::GetRank(); + std::size_t constexpr kRows = 1000, kCols = 100, kBins = 64; + + auto m = std::unique_ptr<DMatrix>{[=]() { + auto dmat = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(); + return dmat->SliceCol(world, rank); + }()}; + + // Generate cuts for distributed environment. + auto const device = rank; + HistogramCuts distributed_cuts = common::DeviceSketch(device, m.get(), kBins); + + // Generate cuts for single node environment + collective::Finalize(); + CHECK_EQ(collective::GetWorldSize(), 1); + HistogramCuts single_node_cuts = common::DeviceSketch(device, m.get(), kBins); + + auto const& sptrs = single_node_cuts.Ptrs(); + auto const& dptrs = distributed_cuts.Ptrs(); + auto const& svals = single_node_cuts.Values(); + auto const& dvals = distributed_cuts.Values(); + auto const& smins = single_node_cuts.MinValues(); + auto const& dmins = distributed_cuts.MinValues(); + + EXPECT_EQ(sptrs.size(), dptrs.size()); + for (size_t i = 0; i < sptrs.size(); ++i) { + EXPECT_EQ(sptrs[i], dptrs[i]) << "rank: " << rank << ", i: " << i; + } + + EXPECT_EQ(svals.size(), dvals.size()); + for (size_t i = 0; i < svals.size(); ++i) { + EXPECT_NEAR(svals[i], dvals[i], 2e-2f) << "rank: " << rank << ", i: " << i; + } + + EXPECT_EQ(smins.size(), dmins.size()); + for (size_t i = 0; i < smins.size(); ++i) { + EXPECT_FLOAT_EQ(smins[i], dmins[i]) << "rank: " << rank << ", i: " << i; + } +} +} // anonymous namespace + +TEST(GPUQuantile, MGPUColumnSplitBasic) { + auto const n_gpus = AllVisibleGPUs(); + if (n_gpus <= 1) { + GTEST_SKIP() << 
"Skipping MGPUColumnSplitBasic test with # GPUs = " << n_gpus; + } + RunWithInMemoryCommunicator(n_gpus, TestColumnSplitBasic); +} + namespace { void TestSameOnAllWorkers(std::int32_t n_gpus) { auto world = collective::GetWorldSize(); @@ -445,7 +497,7 @@ void TestSameOnAllWorkers(std::int32_t n_gpus) { AdapterDeviceSketch(adapter.Value(), n_bins, info, std::numeric_limits<float>::quiet_NaN(), &sketch_distributed); - sketch_distributed.AllReduce(); + sketch_distributed.AllReduce(false); sketch_distributed.Unique(); TestQuantileElemRank(device, sketch_distributed.Data(), sketch_distributed.ColumnsPtr(), true);