Support column split with GPU quantile (#9370)

This commit is contained in:
Rong Ou 2023-07-10 21:15:56 -07:00 committed by GitHub
parent 97ed944209
commit 3632242e0b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 68 additions and 16 deletions

View File

@ -352,7 +352,7 @@ HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins,
}
}
}
sketch_container.MakeCuts(&cuts);
sketch_container.MakeCuts(&cuts, dmat->Info().IsColumnSplit());
return cuts;
}
} // namespace common

View File

@ -501,10 +501,10 @@ void SketchContainer::FixError() {
});
}
void SketchContainer::AllReduce() {
void SketchContainer::AllReduce(bool is_column_split) {
dh::safe_cuda(cudaSetDevice(device_));
auto world = collective::GetWorldSize();
if (world == 1) {
if (world == 1 || is_column_split) {
return;
}
@ -582,13 +582,13 @@ struct InvalidCatOp {
};
} // anonymous namespace
void SketchContainer::MakeCuts(HistogramCuts* p_cuts) {
void SketchContainer::MakeCuts(HistogramCuts* p_cuts, bool is_column_split) {
timer_.Start(__func__);
dh::safe_cuda(cudaSetDevice(device_));
p_cuts->min_vals_.Resize(num_columns_);
// Sync between workers.
this->AllReduce();
this->AllReduce(is_column_split);
// Prune to final number of bins.
this->Prune(num_bins_ + 1);

View File

@ -154,9 +154,9 @@ class SketchContainer {
Span<SketchEntry const> that);
/* \brief Merge quantiles from other GPU workers. */
void AllReduce();
void AllReduce(bool is_column_split);
/* \brief Create the final histogram cut values. */
void MakeCuts(HistogramCuts* cuts);
void MakeCuts(HistogramCuts* cuts, bool is_column_split);
Span<SketchEntry const> Data() const {
return {this->Current().data().get(), this->Current().size()};

View File

@ -106,7 +106,7 @@ void IterativeDMatrix::InitFromCUDA(Context const* ctx, BatchParam const& p,
sketch_containers.clear();
sketch_containers.shrink_to_fit();
final_sketch.MakeCuts(&cuts);
final_sketch.MakeCuts(&cuts, this->info_.IsColumnSplit());
} else {
GetCutsFromRef(ctx, ref, Info().num_col_, p, &cuts);
}

View File

@ -351,7 +351,7 @@ auto MakeUnweightedCutsForTest(Adapter adapter, int32_t num_bins, float missing,
SketchContainer sketch_container(ft, num_bins, adapter.NumColumns(), adapter.NumRows(), 0);
MetaInfo info;
AdapterDeviceSketch(adapter.Value(), num_bins, info, missing, &sketch_container, batch_size);
sketch_container.MakeCuts(&batched_cuts);
sketch_container.MakeCuts(&batched_cuts, info.IsColumnSplit());
return batched_cuts;
}
@ -419,7 +419,7 @@ TEST(HistUtil, AdapterSketchSlidingWindowMemory) {
AdapterDeviceSketch(adapter.Value(), num_bins, info, std::numeric_limits<float>::quiet_NaN(),
&sketch_container);
HistogramCuts cuts;
sketch_container.MakeCuts(&cuts);
sketch_container.MakeCuts(&cuts, info.IsColumnSplit());
size_t bytes_required = detail::RequiredMemory(
num_rows, num_columns, num_rows * num_columns, num_bins, false);
EXPECT_LE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required * 1.05);
@ -449,7 +449,7 @@ TEST(HistUtil, AdapterSketchSlidingWindowWeightedMemory) {
&sketch_container);
HistogramCuts cuts;
sketch_container.MakeCuts(&cuts);
sketch_container.MakeCuts(&cuts, info.IsColumnSplit());
ConsoleLogger::Configure({{"verbosity", "0"}});
size_t bytes_required = detail::RequiredMemory(
num_rows, num_columns, num_rows * num_columns, num_bins, true);
@ -482,7 +482,7 @@ void TestCategoricalSketchAdapter(size_t n, size_t num_categories,
AdapterDeviceSketch(adapter.Value(), num_bins, info,
std::numeric_limits<float>::quiet_NaN(), &container);
HistogramCuts cuts;
container.MakeCuts(&cuts);
container.MakeCuts(&cuts, info.IsColumnSplit());
thrust::sort(x.begin(), x.end());
auto n_uniques = thrust::unique(x.begin(), x.end()) - x.begin();
@ -710,7 +710,7 @@ void TestAdapterSketchFromWeights(bool with_group) {
&sketch_container);
common::HistogramCuts cuts;
sketch_container.MakeCuts(&cuts);
sketch_container.MakeCuts(&cuts, info.IsColumnSplit());
auto dmat = GetDMatrixFromData(storage.HostVector(), kRows, kCols);
if (with_group) {
@ -751,7 +751,7 @@ void TestAdapterSketchFromWeights(bool with_group) {
SketchContainer sketch_container(ft, kBins, kCols, kRows, 0);
AdapterDeviceSketch(adapter.Value(), kBins, info, std::numeric_limits<float>::quiet_NaN(),
&sketch_container);
sketch_container.MakeCuts(&weighted);
sketch_container.MakeCuts(&weighted, info.IsColumnSplit());
ValidateCuts(weighted, dmat.get(), kBins);
}
}

View File

@ -388,7 +388,7 @@ void TestAllReduceBasic(int32_t n_gpus) {
AdapterDeviceSketch(adapter.Value(), n_bins, info,
std::numeric_limits<float>::quiet_NaN(),
&sketch_distributed);
sketch_distributed.AllReduce();
sketch_distributed.AllReduce(false);
sketch_distributed.Unique();
ASSERT_EQ(sketch_distributed.ColumnsPtr().size(),
@ -425,6 +425,58 @@ TEST(GPUQuantile, MGPUAllReduceBasic) {
RunWithInMemoryCommunicator(n_gpus, TestAllReduceBasic, n_gpus);
}
namespace {
// Verifies GPU quantile sketching under a column-wise (vertical) data split:
// cuts computed while the distributed communicator is active must match cuts
// computed for the same column slice after the communicator is torn down
// (i.e. single-node mode), since no cross-worker reduction is needed when
// each worker owns distinct columns.
void TestColumnSplitBasic() {
auto const world = collective::GetWorldSize();
auto const rank = collective::GetRank();
std::size_t constexpr kRows = 1000, kCols = 100, kBins = 64;
// Each rank keeps only its own slice of columns from the generated matrix.
// NOTE(review): assumes SliceCol(world, rank) marks the result as
// column-split (IsColumnSplit() == true) — confirm against DMatrix docs.
auto m = std::unique_ptr<DMatrix>{[=]() {
auto dmat = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix();
return dmat->SliceCol(world, rank);
}()};
// Generate cuts for distributed environment.
auto const device = rank;
HistogramCuts distributed_cuts = common::DeviceSketch(device, m.get(), kBins);
// Generate cuts for single node environment
// Shutting down the communicator here is deliberate: the second sketch must
// run with world size 1 so it takes the non-distributed code path.
collective::Finalize();
CHECK_EQ(collective::GetWorldSize(), 1);
HistogramCuts single_node_cuts = common::DeviceSketch(device, m.get(), kBins);
auto const& sptrs = single_node_cuts.Ptrs();
auto const& dptrs = distributed_cuts.Ptrs();
auto const& svals = single_node_cuts.Values();
auto const& dvals = distributed_cuts.Values();
auto const& smins = single_node_cuts.MinValues();
auto const& dmins = distributed_cuts.MinValues();
// Column pointers (per-feature offsets into the cut values) must agree exactly.
EXPECT_EQ(sptrs.size(), dptrs.size());
for (size_t i = 0; i < sptrs.size(); ++i) {
EXPECT_EQ(sptrs[i], dptrs[i]) << "rank: " << rank << ", i: " << i;
}
// Cut values are floating-point quantile estimates; allow a small tolerance.
EXPECT_EQ(svals.size(), dvals.size());
for (size_t i = 0; i < svals.size(); ++i) {
EXPECT_NEAR(svals[i], dvals[i], 2e-2f) << "rank: " << rank << ", i: " << i;
}
// Per-feature minimum values must match exactly.
EXPECT_EQ(smins.size(), dmins.size());
for (size_t i = 0; i < smins.size(); ++i) {
EXPECT_FLOAT_EQ(smins[i], dmins[i]) << "rank: " << rank << ", i: " << i;
}
}
} // anonymous namespace
// Multi-GPU column-split smoke test: requires at least two visible devices;
// otherwise the test is skipped rather than failed.
TEST(GPUQuantile, MGPUColumnSplitBasic) {
  auto const device_count = AllVisibleGPUs();
  if (device_count < 2) {
    GTEST_SKIP() << "Skipping MGPUColumnSplitBasic test with # GPUs = " << device_count;
  }
  // Spin up an in-memory communicator with one worker per GPU and run the
  // column-split comparison on every rank.
  RunWithInMemoryCommunicator(device_count, TestColumnSplitBasic);
}
namespace {
void TestSameOnAllWorkers(std::int32_t n_gpus) {
auto world = collective::GetWorldSize();
@ -445,7 +497,7 @@ void TestSameOnAllWorkers(std::int32_t n_gpus) {
AdapterDeviceSketch(adapter.Value(), n_bins, info,
std::numeric_limits<float>::quiet_NaN(),
&sketch_distributed);
sketch_distributed.AllReduce();
sketch_distributed.AllReduce(false);
sketch_distributed.Unique();
TestQuantileElemRank(device, sketch_distributed.Data(), sketch_distributed.ColumnsPtr(), true);