Support column split with GPU quantile (#9370)
This commit is contained in:
parent
97ed944209
commit
3632242e0b
@ -352,7 +352,7 @@ HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins,
|
||||
}
|
||||
}
|
||||
}
|
||||
sketch_container.MakeCuts(&cuts);
|
||||
sketch_container.MakeCuts(&cuts, dmat->Info().IsColumnSplit());
|
||||
return cuts;
|
||||
}
|
||||
} // namespace common
|
||||
|
||||
@ -501,10 +501,10 @@ void SketchContainer::FixError() {
|
||||
});
|
||||
}
|
||||
|
||||
void SketchContainer::AllReduce() {
|
||||
void SketchContainer::AllReduce(bool is_column_split) {
|
||||
dh::safe_cuda(cudaSetDevice(device_));
|
||||
auto world = collective::GetWorldSize();
|
||||
if (world == 1) {
|
||||
if (world == 1 || is_column_split) {
|
||||
return;
|
||||
}
|
||||
|
||||
@ -582,13 +582,13 @@ struct InvalidCatOp {
|
||||
};
|
||||
} // anonymous namespace
|
||||
|
||||
void SketchContainer::MakeCuts(HistogramCuts* p_cuts) {
|
||||
void SketchContainer::MakeCuts(HistogramCuts* p_cuts, bool is_column_split) {
|
||||
timer_.Start(__func__);
|
||||
dh::safe_cuda(cudaSetDevice(device_));
|
||||
p_cuts->min_vals_.Resize(num_columns_);
|
||||
|
||||
// Sync between workers.
|
||||
this->AllReduce();
|
||||
this->AllReduce(is_column_split);
|
||||
|
||||
// Prune to final number of bins.
|
||||
this->Prune(num_bins_ + 1);
|
||||
|
||||
@ -154,9 +154,9 @@ class SketchContainer {
|
||||
Span<SketchEntry const> that);
|
||||
|
||||
/* \brief Merge quantiles from other GPU workers. */
|
||||
void AllReduce();
|
||||
void AllReduce(bool is_column_split);
|
||||
/* \brief Create the final histogram cut values. */
|
||||
void MakeCuts(HistogramCuts* cuts);
|
||||
void MakeCuts(HistogramCuts* cuts, bool is_column_split);
|
||||
|
||||
Span<SketchEntry const> Data() const {
|
||||
return {this->Current().data().get(), this->Current().size()};
|
||||
|
||||
@ -106,7 +106,7 @@ void IterativeDMatrix::InitFromCUDA(Context const* ctx, BatchParam const& p,
|
||||
sketch_containers.clear();
|
||||
sketch_containers.shrink_to_fit();
|
||||
|
||||
final_sketch.MakeCuts(&cuts);
|
||||
final_sketch.MakeCuts(&cuts, this->info_.IsColumnSplit());
|
||||
} else {
|
||||
GetCutsFromRef(ctx, ref, Info().num_col_, p, &cuts);
|
||||
}
|
||||
|
||||
@ -351,7 +351,7 @@ auto MakeUnweightedCutsForTest(Adapter adapter, int32_t num_bins, float missing,
|
||||
SketchContainer sketch_container(ft, num_bins, adapter.NumColumns(), adapter.NumRows(), 0);
|
||||
MetaInfo info;
|
||||
AdapterDeviceSketch(adapter.Value(), num_bins, info, missing, &sketch_container, batch_size);
|
||||
sketch_container.MakeCuts(&batched_cuts);
|
||||
sketch_container.MakeCuts(&batched_cuts, info.IsColumnSplit());
|
||||
return batched_cuts;
|
||||
}
|
||||
|
||||
@ -419,7 +419,7 @@ TEST(HistUtil, AdapterSketchSlidingWindowMemory) {
|
||||
AdapterDeviceSketch(adapter.Value(), num_bins, info, std::numeric_limits<float>::quiet_NaN(),
|
||||
&sketch_container);
|
||||
HistogramCuts cuts;
|
||||
sketch_container.MakeCuts(&cuts);
|
||||
sketch_container.MakeCuts(&cuts, info.IsColumnSplit());
|
||||
size_t bytes_required = detail::RequiredMemory(
|
||||
num_rows, num_columns, num_rows * num_columns, num_bins, false);
|
||||
EXPECT_LE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required * 1.05);
|
||||
@ -449,7 +449,7 @@ TEST(HistUtil, AdapterSketchSlidingWindowWeightedMemory) {
|
||||
&sketch_container);
|
||||
|
||||
HistogramCuts cuts;
|
||||
sketch_container.MakeCuts(&cuts);
|
||||
sketch_container.MakeCuts(&cuts, info.IsColumnSplit());
|
||||
ConsoleLogger::Configure({{"verbosity", "0"}});
|
||||
size_t bytes_required = detail::RequiredMemory(
|
||||
num_rows, num_columns, num_rows * num_columns, num_bins, true);
|
||||
@ -482,7 +482,7 @@ void TestCategoricalSketchAdapter(size_t n, size_t num_categories,
|
||||
AdapterDeviceSketch(adapter.Value(), num_bins, info,
|
||||
std::numeric_limits<float>::quiet_NaN(), &container);
|
||||
HistogramCuts cuts;
|
||||
container.MakeCuts(&cuts);
|
||||
container.MakeCuts(&cuts, info.IsColumnSplit());
|
||||
|
||||
thrust::sort(x.begin(), x.end());
|
||||
auto n_uniques = thrust::unique(x.begin(), x.end()) - x.begin();
|
||||
@ -710,7 +710,7 @@ void TestAdapterSketchFromWeights(bool with_group) {
|
||||
&sketch_container);
|
||||
|
||||
common::HistogramCuts cuts;
|
||||
sketch_container.MakeCuts(&cuts);
|
||||
sketch_container.MakeCuts(&cuts, info.IsColumnSplit());
|
||||
|
||||
auto dmat = GetDMatrixFromData(storage.HostVector(), kRows, kCols);
|
||||
if (with_group) {
|
||||
@ -751,7 +751,7 @@ void TestAdapterSketchFromWeights(bool with_group) {
|
||||
SketchContainer sketch_container(ft, kBins, kCols, kRows, 0);
|
||||
AdapterDeviceSketch(adapter.Value(), kBins, info, std::numeric_limits<float>::quiet_NaN(),
|
||||
&sketch_container);
|
||||
sketch_container.MakeCuts(&weighted);
|
||||
sketch_container.MakeCuts(&weighted, info.IsColumnSplit());
|
||||
ValidateCuts(weighted, dmat.get(), kBins);
|
||||
}
|
||||
}
|
||||
|
||||
@ -388,7 +388,7 @@ void TestAllReduceBasic(int32_t n_gpus) {
|
||||
AdapterDeviceSketch(adapter.Value(), n_bins, info,
|
||||
std::numeric_limits<float>::quiet_NaN(),
|
||||
&sketch_distributed);
|
||||
sketch_distributed.AllReduce();
|
||||
sketch_distributed.AllReduce(false);
|
||||
sketch_distributed.Unique();
|
||||
|
||||
ASSERT_EQ(sketch_distributed.ColumnsPtr().size(),
|
||||
@ -425,6 +425,58 @@ TEST(GPUQuantile, MGPUAllReduceBasic) {
|
||||
RunWithInMemoryCommunicator(n_gpus, TestAllReduceBasic, n_gpus);
|
||||
}
|
||||
|
||||
namespace {
|
||||
void TestColumnSplitBasic() {
|
||||
auto const world = collective::GetWorldSize();
|
||||
auto const rank = collective::GetRank();
|
||||
std::size_t constexpr kRows = 1000, kCols = 100, kBins = 64;
|
||||
|
||||
auto m = std::unique_ptr<DMatrix>{[=]() {
|
||||
auto dmat = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix();
|
||||
return dmat->SliceCol(world, rank);
|
||||
}()};
|
||||
|
||||
// Generate cuts for distributed environment.
|
||||
auto const device = rank;
|
||||
HistogramCuts distributed_cuts = common::DeviceSketch(device, m.get(), kBins);
|
||||
|
||||
// Generate cuts for single node environment
|
||||
collective::Finalize();
|
||||
CHECK_EQ(collective::GetWorldSize(), 1);
|
||||
HistogramCuts single_node_cuts = common::DeviceSketch(device, m.get(), kBins);
|
||||
|
||||
auto const& sptrs = single_node_cuts.Ptrs();
|
||||
auto const& dptrs = distributed_cuts.Ptrs();
|
||||
auto const& svals = single_node_cuts.Values();
|
||||
auto const& dvals = distributed_cuts.Values();
|
||||
auto const& smins = single_node_cuts.MinValues();
|
||||
auto const& dmins = distributed_cuts.MinValues();
|
||||
|
||||
EXPECT_EQ(sptrs.size(), dptrs.size());
|
||||
for (size_t i = 0; i < sptrs.size(); ++i) {
|
||||
EXPECT_EQ(sptrs[i], dptrs[i]) << "rank: " << rank << ", i: " << i;
|
||||
}
|
||||
|
||||
EXPECT_EQ(svals.size(), dvals.size());
|
||||
for (size_t i = 0; i < svals.size(); ++i) {
|
||||
EXPECT_NEAR(svals[i], dvals[i], 2e-2f) << "rank: " << rank << ", i: " << i;
|
||||
}
|
||||
|
||||
EXPECT_EQ(smins.size(), dmins.size());
|
||||
for (size_t i = 0; i < smins.size(); ++i) {
|
||||
EXPECT_FLOAT_EQ(smins[i], dmins[i]) << "rank: " << rank << ", i: " << i;
|
||||
}
|
||||
}
|
||||
} // anonymous namespace
|
||||
|
||||
TEST(GPUQuantile, MGPUColumnSplitBasic) {
|
||||
auto const n_gpus = AllVisibleGPUs();
|
||||
if (n_gpus <= 1) {
|
||||
GTEST_SKIP() << "Skipping MGPUColumnSplitBasic test with # GPUs = " << n_gpus;
|
||||
}
|
||||
RunWithInMemoryCommunicator(n_gpus, TestColumnSplitBasic);
|
||||
}
|
||||
|
||||
namespace {
|
||||
void TestSameOnAllWorkers(std::int32_t n_gpus) {
|
||||
auto world = collective::GetWorldSize();
|
||||
@ -445,7 +497,7 @@ void TestSameOnAllWorkers(std::int32_t n_gpus) {
|
||||
AdapterDeviceSketch(adapter.Value(), n_bins, info,
|
||||
std::numeric_limits<float>::quiet_NaN(),
|
||||
&sketch_distributed);
|
||||
sketch_distributed.AllReduce();
|
||||
sketch_distributed.AllReduce(false);
|
||||
sketch_distributed.Unique();
|
||||
TestQuantileElemRank(device, sketch_distributed.Data(), sketch_distributed.ColumnsPtr(), true);
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user